From c0a672fdee21fcf1b049cbf2e6a2427b694c4b40 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 30 Jan 2025 14:21:51 +0900 Subject: [PATCH 1/6] support mixed phase cepstrum --- diffsptk/modules/c2mpir.py | 2 +- diffsptk/modules/mglsadf.py | 268 +++++++++++++++++++++++++++--------- tests/test_mglsadf.py | 118 ++++++++++++++-- 3 files changed, 306 insertions(+), 82 deletions(-) diff --git a/diffsptk/modules/c2mpir.py b/diffsptk/modules/c2mpir.py index bee2da8a..e4854889 100644 --- a/diffsptk/modules/c2mpir.py +++ b/diffsptk/modules/c2mpir.py @@ -76,7 +76,7 @@ def forward(self, c): @staticmethod def _forward(c, ir_length, n_fft): C = torch.fft.fft(c, n=n_fft) - h = torch.fft.ifft(cexp(C))[..., :ir_length].real + h = torch.fft.ifft(cexp(C)).real[..., :ir_length] return h _func = _forward diff --git a/diffsptk/modules/mglsadf.py b/diffsptk/modules/mglsadf.py index 601e6a02..ea9ede19 100644 --- a/diffsptk/modules/mglsadf.py +++ b/diffsptk/modules/mglsadf.py @@ -15,10 +15,12 @@ # ------------------------------------------------------------------------ # import torch +import torch.nn.functional as F from torch import nn from ..misc.utils import Lambda, check_size, get_gamma, remove_gain from .b2mc import MLSADigitalFilterCoefficientsToMelCepstrum +from .c2mpir import CepstrumToMinimumPhaseImpulseResponse from .gnorm import GeneralizedCepstrumGainNormalization from .istft import InverseShortTimeFourierTransform from .linear_intpl import LinearInterpolation @@ -28,6 +30,23 @@ from .stft import ShortTimeFourierTransform +def is_array_like(x): + """Return True if the input is array-like. + + Parameters + ---------- + x : object + Any object. + + Returns + ------- + out : bool + True if the input is array-like. + + """ + return isinstance(x, (tuple, list)) + + def mirror(x, half=False): """Mirror the input tensor. @@ -58,8 +77,9 @@ class PseudoMGLSADigitalFilter(nn.Module): Parameters ---------- - filter_order : int >= 0 - Order of filter coefficients, :math:`M`. + filter_order : int >= 0 or tuple[int, int] + Order of filter coefficients, :math:`M` or :math:`(M, N)`. A tuple input is + allowed only if **phase** is 'mixed'. frame_period : int >= 1 Frame period, :math:`P`. @@ -76,7 +96,7 @@ class PseudoMGLSADigitalFilter(nn.Module): ignore_gain : bool If True, perform filtering without gain. - phase : ['minimum', 'maximum', 'zero'] + phase : ['minimum', 'maximum', 'zero', 'mixed'] Filter type. mode : ['multi-stage', 'single-stage', 'freq-domain'] @@ -92,10 +112,10 @@ class PseudoMGLSADigitalFilter(nn.Module): taylor_order : int >= 0 Order of Taylor series expansion (valid only if **mode** is 'multi-stage'). - cep_order : int >= 0 + cep_order : int >= 0 or tuple[int, int] Order of linear cepstrum (valid only if **mode** is 'multi-stage'). - ir_length : int >= 1 + ir_length : int >= 1 or tuple[int, int] Length of impulse response (valid only if **mode** is 'single-stage'). **kwargs : additional keyword arguments @@ -124,11 +144,18 @@ def __init__( ): super().__init__() - self.filter_order = filter_order self.frame_period = frame_period + # Format parameters. + if phase == "mixed" and not is_array_like(filter_order): + filter_order = (filter_order, filter_order) gamma = get_gamma(gamma, c) + if phase == "mixed": + self.split_sections = (filter_order[0] + 1, filter_order[1]) + else: + self.split_sections = (filter_order + 1,) + if mode == "multi-stage": self.mglsadf = MultiStageFIRFilter( filter_order, @@ -170,8 +197,11 @@ def forward(self, x, mc): x : Tensor [shape=(..., T)] Excitation signal. - mc : Tensor [shape=(..., T/P, M+1)] - Mel-generalized cepstrum, not MLSA digital filter coefficients. + mc : Tensor [shape=(..., T/P, M+1)] or [shape=(..., T/P, M+1+N)] + Mel-generalized cepstrum, not MLSA digital filter coefficients. Note that + the mixed-phase case assumes that the coefficients are of the form + c_{-M}, ..., c_{0}, ..., c_{N}, where M is the order of the minimum-phase + part and N is the order of the maximum-phase part. Returns ------- @@ -192,9 +222,13 @@ def forward(self, x, mc): tensor([[0.4011, 0.8760, 3.5677, 4.8725]]) """ - check_size(mc.size(-1), self.filter_order + 1, "dimension of mel-cepstrum") + check_size(mc.size(-1), sum(self.split_sections), "dimension of mel-cepstrum") check_size(x.size(-1), mc.size(-2) * self.frame_period, "sequence length") - + if len(self.split_sections) != 1: + mc_min, mc_max = torch.split(mc, self.split_sections, dim=-1) + mc_min = mc_min.flip(-1) + mc_max = F.pad(mc_max, (1, 0)) + mc = (mc_min, mc_max) # (c0, c-1, ..., c-M), (0, c1, ..., cN) y = self.mglsadf(x, mc) return y @@ -224,36 +258,64 @@ def __init__( if alpha == 0 and gamma == 0: cep_order = filter_order + # Prepare padding module. if self.phase == "minimum": - self.pad = nn.ConstantPad1d((cep_order, 0), 0) + padding = (cep_order, 0) elif self.phase == "maximum": - self.pad = nn.ConstantPad1d((0, cep_order), 0) + padding = (0, cep_order) elif self.phase == "zero": - self.pad = nn.ConstantPad1d((cep_order, cep_order), 0) + padding = (cep_order, cep_order) + elif self.phase == "mixed": + padding = cep_order if is_array_like(cep_order) else (cep_order, cep_order) else: raise ValueError(f"phase {phase} is not supported.") + self.pad = nn.ConstantPad1d(padding, 0) + + # Prepare frequency transformation module. + if self.phase == "mixed": + self.mgc2c = nn.ModuleList() + for i in range(2): + self.mgc2c.append( + MelGeneralizedCepstrumToMelGeneralizedCepstrum( + filter_order[i], + padding[i], + in_alpha=alpha, + in_gamma=gamma, + n_fft=n_fft, + ) + ) + else: + self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum( + filter_order, + cep_order, + in_alpha=alpha, + in_gamma=gamma, + n_fft=n_fft, + ) - self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum( - filter_order, - cep_order, - in_alpha=alpha, - in_gamma=gamma, - n_fft=n_fft, - ) self.linear_intpl = LinearInterpolation(frame_period) def forward(self, x, mc): - c = self.mgc2c(mc) - c0, c = remove_gain(c, value=0, return_gain=True) - - if self.phase == "minimum": - c = c.flip(-1) - elif self.phase == "maximum": - pass - elif self.phase == "zero": - c = mirror(c, half=True) + if self.phase == "mixed": + mc_min, mc_max = mc + c_min = self.mgc2c[0](mc_min) + c_max = self.mgc2c[1](mc_max) + c0 = c_min[..., :1] + c_max[..., :1] + c1_min = c_min[..., 1:].flip(-1) + c0_dummy = torch.zeros_like(c0) + c1_max = c_max[..., 1:] + c = torch.cat([c1_min, c0_dummy, c1_max], dim=-1) else: - raise RuntimeError + c = self.mgc2c(mc) + c0, c = remove_gain(c, value=0, return_gain=True) + if self.phase == "minimum": + c = c.flip(-1) + elif self.phase == "maximum": + pass + elif self.phase == "zero": + c = mirror(c, half=True) + else: + raise RuntimeError c = self.linear_intpl(c) @@ -287,18 +349,28 @@ def __init__( self.ignore_gain = ignore_gain self.phase = phase + self.n_fft = n_fft + # Prepare padding module. taps = ir_length - 1 if self.phase == "minimum": - self.pad = nn.ConstantPad1d((taps, 0), 0) + padding = (taps, 0) elif self.phase == "maximum": - self.pad = nn.ConstantPad1d((0, taps), 0) + padding = (0, taps) elif self.phase == "zero": - self.pad = nn.ConstantPad1d((taps, taps), 0) + padding = (taps, taps) + elif self.phase == "mixed": + padding = ( + (ir_length[0] - 1, ir_length[1] - 1) + if is_array_like(ir_length) + else (taps, taps) + ) else: raise ValueError(f"phase {phase} is not supported.") + self.pad = nn.ConstantPad1d(padding, 0) + self.padding = padding - if self.phase in ["minimum", "maximum"]: + if self.phase in ("minimum", "maximum"): self.mgc2ir = MelGeneralizedCepstrumToMelGeneralizedCepstrum( filter_order, ir_length - 1, @@ -308,7 +380,7 @@ def __init__( out_mul=True, n_fft=n_fft, ) - else: + elif self.phase == "zero": self.mgc2c = MelGeneralizedCepstrumToMelGeneralizedCepstrum( filter_order, ir_length - 1, @@ -320,24 +392,52 @@ def __init__( Lambda(lambda x: torch.fft.hfft(x, n=n_fft)), Lambda(lambda x: torch.fft.ifft(torch.exp(x)).real[..., :ir_length]), ) + elif self.phase == "mixed": + self.mgc2c = nn.ModuleList() + for i in range(2): + self.mgc2c.append( + MelGeneralizedCepstrumToMelGeneralizedCepstrum( + filter_order[i], + padding[i], + in_alpha=alpha, + in_gamma=gamma, + n_fft=n_fft, + ) + ) + self.c2ir = CepstrumToMinimumPhaseImpulseResponse( + n_fft - 1, n_fft, n_fft=n_fft + ) + else: + raise ValueError(f"phase {phase} is not supported.") + self.linear_intpl = LinearInterpolation(frame_period) def forward(self, x, mc): - if self.phase == "zero": + if self.phase == "minimum": + h = self.mgc2ir(mc) + h = h.flip(-1) + elif self.phase == "maximum": + h = self.mgc2ir(mc) + elif self.phase == "zero": c = self.mgc2c(mc) c[..., 1:] *= 0.5 if self.ignore_gain: c = remove_gain(c, value=0) h = self.c2ir(c) - else: - h = self.mgc2ir(mc) - - if self.phase == "minimum": - h = h.flip(-1) - elif self.phase == "maximum": - pass - elif self.phase == "zero": h = mirror(h) + elif self.phase == "mixed": + mc_min, mc_max = mc + c_min = self.mgc2c[0](mc_min) + c_max = self.mgc2c[1](mc_max) + if self.ignore_gain: + c0 = torch.zeros_like(c_min[..., :1]) + else: + c0 = c_min[..., :1] + c_max[..., :1] + c = torch.cat([c_min[..., 1:].flip(-1), c0, c_max[..., 1:]], dim=-1) + c = F.pad(c, (0, self.n_fft - c.size(-1))) + c = torch.roll(c, -self.padding[0], dims=-1) + h = self.c2ir(c) + h = torch.roll(h, self.padding[0], dims=-1)[..., : sum(self.padding) + 1] else: raise RuntimeError @@ -348,10 +448,6 @@ def forward(self, x, mc): h = h / h[..., -1:] elif self.phase == "maximum": h = h / h[..., :1] - elif self.phase == "zero": - pass - else: - raise RuntimeError x = self.pad(x) x = x.unfold(-1, h.size(-1), 1) @@ -379,14 +475,42 @@ def __init__( assert 2 * frame_period < frame_length self.ignore_gain = ignore_gain + self.phase = phase if self.ignore_gain: - self.gnorm = GeneralizedCepstrumGainNormalization(filter_order, gamma=gamma) - self.mc2b = MelCepstrumToMLSADigitalFilterCoefficients( - filter_order, alpha=alpha - ) - self.b2mc = MLSADigitalFilterCoefficientsToMelCepstrum( - filter_order, alpha=alpha + self.gnorm = nn.ModuleList() + self.mc2b = nn.ModuleList() + self.b2mc = nn.ModuleList() + self.mgc2sp = nn.ModuleList() + + if not is_array_like(filter_order): + filter_order = (filter_order, filter_order) + + n = 2 if phase == "mixed" else 1 + for i in range(n): + if self.ignore_gain: + self.gnorm.append( + GeneralizedCepstrumGainNormalization(filter_order[i], gamma=gamma) + ) + self.mc2b.append( + MelCepstrumToMLSADigitalFilterCoefficients( + filter_order[i], alpha=alpha + ) + ) + self.b2mc.append( + MLSADigitalFilterCoefficientsToMelCepstrum( + filter_order[i], alpha=alpha + ) + ) + self.mgc2sp.append( + MelGeneralizedCepstrumToSpectrum( + filter_order[i], + fft_length, + alpha=alpha, + gamma=gamma, + out_format="complex", + n_fft=n_fft, + ) ) self.stft = ShortTimeFourierTransform( @@ -395,23 +519,31 @@ def __init__( self.istft = InverseShortTimeFourierTransform( frame_length, frame_period, fft_length, **stft_kwargs ) - self.mgc2sp = MelGeneralizedCepstrumToSpectrum( - filter_order, - fft_length, - alpha=alpha, - gamma=gamma, - out_format="magnitude" if phase == "zero" else "complex", - n_fft=n_fft, - ) def forward(self, x, mc): - if self.ignore_gain: - b = self.mc2b(mc) - b = self.gnorm(b) - b[..., 0] = 0 - mc = self.b2mc(b) + if torch.is_tensor(mc): + mc = [mc] + + Hs = [] + for i, c in enumerate(mc): + if self.ignore_gain: + b = self.mc2b[i](c) + b = self.gnorm[i](b) + b[..., 0] = 0 + c = self.b2mc[i](b) + Hs.append(self.mgc2sp[i](c)) + + if self.phase == "minimum": + H = Hs[0] + elif self.phase == "maximum": + H = Hs[0].conj() + elif self.phase == "zero": + H = Hs[0].abs() + elif self.phase == "mixed": + H = Hs[0] * Hs[1].conj() + else: + raise RuntimeError - H = self.mgc2sp(mc) X = self.stft(x) Y = H * X y = self.istft(Y, out_length=x.size(-1)) diff --git a/tests/test_mglsadf.py b/tests/test_mglsadf.py index 434cf8a9..5132158c 100644 --- a/tests/test_mglsadf.py +++ b/tests/test_mglsadf.py @@ -18,6 +18,7 @@ import numpy as np import pytest +import torch import diffsptk import tests.utils as U @@ -71,19 +72,110 @@ def test_compatibility( ) -@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("phase", ["zero", "maximum"]) @pytest.mark.parametrize("ignore_gain", [False, True]) -@pytest.mark.parametrize("phase", ["minimum", "maximum", "zero"]) -@pytest.mark.parametrize("mode", ["multi-stage", "single-stage", "freq-domain"]) -def test_differentiable(device, ignore_gain, phase, mode, B=4, T=20, P=2, M=4): - if mode == "multi-stage": - params = {"cep_order": 10} - elif mode == "single-stage": - params = {"ir_length": 20, "n_fft": 32} - elif mode == "freq-domain": - params = {"frame_length": 6, "fft_length": 16} +def test_zero_and_maximum_phase( + phase, + ignore_gain, + alpha=0.42, + M=24, + P=80, + L=400, + fft_length=512, + B=2, +): + T = os.path.getsize("tools/SPTK/asset/data.short") // 2 + cmd_x = f"nrand -l {T}" + x = torch.from_numpy(U.call(cmd_x)) - mglsadf = diffsptk.MLSA( - M, P, ignore_gain=ignore_gain, phase=phase, mode=mode, **params + cmd_mc = ( + f"x2x +sd tools/SPTK/asset/data.short | " + f"frame -p {P} -l {L} | " + f"window -w 1 -n 1 -l {L} -L {fft_length} | " + f"mgcep -a {alpha} -m {M} -l {fft_length} -E -60" + ) + mc = torch.from_numpy(U.call(cmd_mc).reshape(-1, M + 1)) + + params1 = {"mode": "multi-stage", "cep_order": 200} + mglsadf1 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase=phase, **params1 + ) + y1 = mglsadf1(x, mc).cpu().numpy() + + params2 = {"mode": "single-stage", "ir_length": 200, "n_fft": 512} + mglsadf2 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase=phase, **params2 ) - U.check_differentiability(device, mglsadf, [(B, T), (B, T // P, M + 1)]) + y2 = mglsadf2(x, mc).cpu().numpy() + assert np.corrcoef(y1, y2)[0, 1] > 0.99 + + params3 = {"mode": "freq-domain", "frame_length": L, "fft_length": fft_length} + mglsadf3 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase=phase, **params3 + ) + y3 = mglsadf3(x, mc).cpu().numpy() + assert np.corrcoef(y1, y3)[0, 1] > 0.98 + + device = "cpu" + S = T // 10 + U.check_differentiability(device, mglsadf1, [(B, S), (B, S // P, M + 1)]) + U.check_differentiability(device, mglsadf2, [(B, S), (B, S // P, M + 1)]) + U.check_differentiability(device, mglsadf3, [(B, S), (B, S // P, M + 1)]) + + +@pytest.mark.parametrize("ignore_gain", [False, True]) +def test_mixed_phase( + ignore_gain, + alpha=0.42, + M=24, + P=80, + L=400, + fft_length=512, + B=2, +): + T = os.path.getsize("tools/SPTK/asset/data.short") // 2 + cmd_x = f"nrand -l {T}" + x = torch.from_numpy(U.call(cmd_x)) + + cmd_mc = ( + f"x2x +sd tools/SPTK/asset/data.short | " + f"frame -p {P} -l {L} | " + f"window -w 1 -n 1 -l {L} -L {fft_length} | " + f"mgcep -a {alpha} -m {M} -l {fft_length} -E -60" + ) + mc = torch.from_numpy(U.call(cmd_mc).reshape(-1, M + 1)) + half_mc = mc[..., 1:] * 0.5 + mc_mix = torch.cat([half_mc.flip(-1), mc[..., :1], half_mc], dim=-1) + + params0 = {"mode": "multi-stage", "cep_order": 200} + mglsadf0 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase="zero", **params0 + ) + y0 = mglsadf0(x, mc).cpu().numpy() + + params1 = params0 + mglsadf1 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase="mixed", **params1 + ) + y1 = mglsadf1(x, mc_mix).cpu().numpy() + assert U.allclose(y0, y1) + + params2 = {"mode": "single-stage", "ir_length": 200, "n_fft": 512} + mglsadf2 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase="mixed", **params2 + ) + y2 = mglsadf2(x, mc_mix).cpu().numpy() + assert np.corrcoef(y1, y2)[0, 1] > 0.99 + + params3 = {"mode": "freq-domain", "frame_length": L, "fft_length": fft_length} + mglsadf3 = diffsptk.MLSA( + M, P, ignore_gain=ignore_gain, alpha=alpha, phase="mixed", **params3 + ) + y3 = mglsadf3(x, mc_mix).cpu().numpy() + assert np.corrcoef(y1, y3)[0, 1] > 0.99 + + device = "cpu" + S = T // 10 + U.check_differentiability(device, mglsadf1, [(B, S), (B, S // P, 2 * M + 1)]) + U.check_differentiability(device, mglsadf2, [(B, S), (B, S // P, 2 * M + 1)]) + U.check_differentiability(device, mglsadf3, [(B, S), (B, S // P, 2 * M + 1)]) From f0acad36917027645fd85a11abdd4182d480b0a3 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Fri, 31 Jan 2025 19:04:22 +0900 Subject: [PATCH 2/6] swap max/min part --- diffsptk/modules/mglsadf.py | 15 +++++++-------- tests/test_mglsadf.py | 13 +++++++++---- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/diffsptk/modules/mglsadf.py b/diffsptk/modules/mglsadf.py index ea9ede19..a2a794d4 100644 --- a/diffsptk/modules/mglsadf.py +++ b/diffsptk/modules/mglsadf.py @@ -78,7 +78,7 @@ class PseudoMGLSADigitalFilter(nn.Module): Parameters ---------- filter_order : int >= 0 or tuple[int, int] - Order of filter coefficients, :math:`M` or :math:`(M, N)`. A tuple input is + Order of filter coefficients, :math:`M` or :math:`(N, M)`. A tuple input is allowed only if **phase** is 'mixed'. frame_period : int >= 1 @@ -152,7 +152,7 @@ def __init__( gamma = get_gamma(gamma, c) if phase == "mixed": - self.split_sections = (filter_order[0] + 1, filter_order[1]) + self.split_sections = (filter_order[0], filter_order[1] + 1) else: self.split_sections = (filter_order + 1,) @@ -197,10 +197,10 @@ def forward(self, x, mc): x : Tensor [shape=(..., T)] Excitation signal. - mc : Tensor [shape=(..., T/P, M+1)] or [shape=(..., T/P, M+1+N)] + mc : Tensor [shape=(..., T/P, M+1)] or [shape=(..., T/P, N+M+1)] Mel-generalized cepstrum, not MLSA digital filter coefficients. Note that the mixed-phase case assumes that the coefficients are of the form - c_{-M}, ..., c_{0}, ..., c_{N}, where M is the order of the minimum-phase + c_{-N}, ..., c_{0}, ..., c_{M}, where M is the order of the minimum-phase part and N is the order of the maximum-phase part. Returns @@ -225,10 +225,9 @@ def forward(self, x, mc): check_size(mc.size(-1), sum(self.split_sections), "dimension of mel-cepstrum") check_size(x.size(-1), mc.size(-2) * self.frame_period, "sequence length") if len(self.split_sections) != 1: - mc_min, mc_max = torch.split(mc, self.split_sections, dim=-1) - mc_min = mc_min.flip(-1) - mc_max = F.pad(mc_max, (1, 0)) - mc = (mc_min, mc_max) # (c0, c-1, ..., c-M), (0, c1, ..., cN) + mc_max, mc_min = torch.split(mc, self.split_sections, dim=-1) + mc_max = F.pad(mc_max.flip(-1), (1, 0)) + mc = (mc_min, mc_max) # (c0, c1, ..., cM), (0, c-1, ..., c-N) y = self.mglsadf(x, mc) return y diff --git a/tests/test_mglsadf.py b/tests/test_mglsadf.py index 5132158c..ab43dee4 100644 --- a/tests/test_mglsadf.py +++ b/tests/test_mglsadf.py @@ -123,8 +123,10 @@ def test_zero_and_maximum_phase( U.check_differentiability(device, mglsadf3, [(B, S), (B, S // P, M + 1)]) +@pytest.mark.parametrize("phase", ["zero", "maximum"]) @pytest.mark.parametrize("ignore_gain", [False, True]) def test_mixed_phase( + phase, ignore_gain, alpha=0.42, M=24, @@ -144,12 +146,15 @@ def test_mixed_phase( f"mgcep -a {alpha} -m {M} -l {fft_length} -E -60" ) mc = torch.from_numpy(U.call(cmd_mc).reshape(-1, M + 1)) - half_mc = mc[..., 1:] * 0.5 - mc_mix = torch.cat([half_mc.flip(-1), mc[..., :1], half_mc], dim=-1) + if phase == "zero": + half_mc = mc[..., 1:] * 0.5 + mc_mix = torch.cat([half_mc.flip(-1), mc[..., :1], half_mc], dim=-1) + elif phase == "maximum": + mc_mix = torch.cat([mc.flip(-1), 0 * mc[..., 1:]], dim=-1) params0 = {"mode": "multi-stage", "cep_order": 200} mglsadf0 = diffsptk.MLSA( - M, P, ignore_gain=ignore_gain, alpha=alpha, phase="zero", **params0 + M, P, ignore_gain=ignore_gain, alpha=alpha, phase=phase, **params0 ) y0 = mglsadf0(x, mc).cpu().numpy() @@ -172,7 +177,7 @@ def test_mixed_phase( M, P, ignore_gain=ignore_gain, alpha=alpha, phase="mixed", **params3 ) y3 = mglsadf3(x, mc_mix).cpu().numpy() - assert np.corrcoef(y1, y3)[0, 1] > 0.99 + assert np.corrcoef(y1, y3)[0, 1] > 0.98 device = "cpu" S = T // 10 From 1f918a8fff60a3b45e2b18771346171d4f53e838 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Fri, 31 Jan 2025 19:44:51 +0900 Subject: [PATCH 3/6] fix consistency --- diffsptk/modules/imglsadf.py | 7 +++++-- diffsptk/modules/mglsadf.py | 19 +++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/diffsptk/modules/imglsadf.py b/diffsptk/modules/imglsadf.py index b62cf052..5d94a163 100644 --- a/diffsptk/modules/imglsadf.py +++ b/diffsptk/modules/imglsadf.py @@ -35,8 +35,11 @@ def forward(self, y, mc): y : Tensor [shape=(..., T)] Audio signal. - mc : Tensor [shape=(..., T/P, M+1)] - Mel-generalized cepstrum, not MLSA digital filter coefficients. + mc : Tensor [shape=(..., T/P, M+1)] or [shape=(..., T/P, N+M+1)] + Mel-generalized cepstrum, not MLSA digital filter coefficients. Note that + the mixed-phase case assumes that the coefficients are of the form + c_{-N}, ..., c_{0}, ..., c_{M}, where M is the order of the minimum-phase + part and N is the order of the maximum-phase part. Returns ------- diff --git a/diffsptk/modules/mglsadf.py b/diffsptk/modules/mglsadf.py index a2a794d4..f098ba22 100644 --- a/diffsptk/modules/mglsadf.py +++ b/diffsptk/modules/mglsadf.py @@ -156,15 +156,26 @@ def __init__( else: self.split_sections = (filter_order + 1,) + def flip(x): + if is_array_like(x): + return x[1], x[0] + return x + + flip_keys = ("cep_order", "ir_length") + modified_kwargs = kwargs.copy() + for key in flip_keys: + if key in kwargs: + modified_kwargs[key] = flip(kwargs[key]) + if mode == "multi-stage": self.mglsadf = MultiStageFIRFilter( - filter_order, + flip(filter_order), frame_period, alpha=alpha, gamma=gamma, ignore_gain=ignore_gain, phase=phase, - **kwargs, + **modified_kwargs, ) elif mode == "single-stage": self.mglsadf = SingleStageFIRFilter( @@ -174,7 +185,7 @@ def __init__( gamma=gamma, ignore_gain=ignore_gain, phase=phase, - **kwargs, + **modified_kwargs, ) elif mode == "freq-domain": self.mglsadf = FrequencyDomainFIRFilter( @@ -184,7 +195,7 @@ def __init__( gamma=gamma, ignore_gain=ignore_gain, phase=phase, - **kwargs, + **modified_kwargs, ) else: raise ValueError(f"mode {mode} is not supported.") From d8986bf63eb04cdf2b5177bb45f0f2b1f58319b8 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Tue, 4 Feb 2025 11:02:08 +0900 Subject: [PATCH 4/6] bump version --- README.md | 2 +- diffsptk/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 91fceba1..f4d5d903 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ *diffsptk* is a differentiable version of [SPTK](https://github.com/sp-nitech/SPTK) based on the PyTorch framework. [![Latest Manual](https://img.shields.io/badge/docs-latest-blue.svg)](https://sp-nitech.github.io/diffsptk/latest/) -[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/2.3.0/) +[![Stable Manual](https://img.shields.io/badge/docs-stable-blue.svg)](https://sp-nitech.github.io/diffsptk/2.4.0/) [![Downloads](https://static.pepy.tech/badge/diffsptk)](https://pepy.tech/project/diffsptk) [![Python Version](https://img.shields.io/pypi/pyversions/diffsptk.svg)](https://pypi.python.org/pypi/diffsptk) [![PyTorch Version](https://img.shields.io/badge/pytorch-2.0.0%20%7C%202.5.1-orange.svg)](https://pypi.python.org/pypi/diffsptk) diff --git a/diffsptk/version.py b/diffsptk/version.py index 55e47090..3d67cd6b 100644 --- a/diffsptk/version.py +++ b/diffsptk/version.py @@ -1 +1 @@ -__version__ = "2.3.0" +__version__ = "2.4.0" From bc539ac798a0f7c3ec5a14a8637071bb45ee8fb9 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Tue, 4 Feb 2025 18:18:36 +0900 Subject: [PATCH 5/6] bug fix --- diffsptk/modules/mglsadf.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/diffsptk/modules/mglsadf.py b/diffsptk/modules/mglsadf.py index f098ba22..12105935 100644 --- a/diffsptk/modules/mglsadf.py +++ b/diffsptk/modules/mglsadf.py @@ -164,12 +164,13 @@ def flip(x): flip_keys = ("cep_order", "ir_length") modified_kwargs = kwargs.copy() for key in flip_keys: - if key in kwargs: + if key in kwargs: modified_kwargs[key] = flip(kwargs[key]) + flipped_filter_order = flip(filter_order) if mode == "multi-stage": self.mglsadf = MultiStageFIRFilter( - flip(filter_order), + flipped_filter_order, frame_period, alpha=alpha, gamma=gamma, @@ -179,7 +180,7 @@ def flip(x): ) elif mode == "single-stage": self.mglsadf = SingleStageFIRFilter( - filter_order, + flipped_filter_order, frame_period, alpha=alpha, gamma=gamma, @@ -189,7 +190,7 @@ def flip(x): ) elif mode == "freq-domain": self.mglsadf = FrequencyDomainFIRFilter( - filter_order, + flipped_filter_order, frame_period, alpha=alpha, gamma=gamma, From 3a38bf578e596b8d3c693e11123c908d4936abac Mon Sep 17 00:00:00 2001 From: takenori-y Date: Tue, 4 Feb 2025 18:44:18 +0900 Subject: [PATCH 6/6] fix test --- tests/test_mglsadf.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_mglsadf.py b/tests/test_mglsadf.py index ab43dee4..a1215fbb 100644 --- a/tests/test_mglsadf.py +++ b/tests/test_mglsadf.py @@ -29,7 +29,16 @@ @pytest.mark.parametrize("mode", ["multi-stage", "single-stage", "freq-domain"]) @pytest.mark.parametrize("c", [0, 10]) def test_compatibility( - device, ignore_gain, mode, c, alpha=0.42, M=24, P=80, L=400, fft_length=512 + device, + ignore_gain, + mode, + c, + alpha=0.42, + M=24, + P=80, + L=400, + fft_length=512, + B=2, ): if mode == "multi-stage": params = {"taylor_order": 7, "cep_order": 100} @@ -71,12 +80,16 @@ def test_compatibility( eq=lambda a, b: np.corrcoef(a, b)[0, 1] > 0.98, ) + S = T // 10 + U.check_differentiability(device, mglsadf, [(B, S), (B, S // P, M + 1)]) + @pytest.mark.parametrize("phase", ["zero", "maximum"]) @pytest.mark.parametrize("ignore_gain", [False, True]) def test_zero_and_maximum_phase( phase, ignore_gain, + device="cpu", alpha=0.42, M=24, P=80, @@ -116,7 +129,6 @@ def test_zero_and_maximum_phase( y3 = mglsadf3(x, mc).cpu().numpy() assert np.corrcoef(y1, y3)[0, 1] > 0.98 - device = "cpu" S = T // 10 U.check_differentiability(device, mglsadf1, [(B, S), (B, S // P, M + 1)]) U.check_differentiability(device, mglsadf2, [(B, S), (B, S // P, M + 1)]) @@ -128,6 +140,7 @@ def test_zero_and_maximum_phase( def test_mixed_phase( phase, ignore_gain, + device="cpu", alpha=0.42, M=24, P=80, @@ -179,7 +192,6 @@ def test_mixed_phase( y3 = mglsadf3(x, mc_mix).cpu().numpy() assert np.corrcoef(y1, y3)[0, 1] > 0.98 - device = "cpu" S = T // 10 U.check_differentiability(device, mglsadf1, [(B, S), (B, S // P, 2 * M + 1)]) U.check_differentiability(device, mglsadf2, [(B, S), (B, S // P, 2 * M + 1)])