diff --git a/demucs/hdemucs.py b/demucs/hdemucs.py index 9d2b0d24..711d4715 100644 --- a/demucs/hdemucs.py +++ b/demucs/hdemucs.py @@ -691,7 +691,7 @@ def forward(self, mix): length = x.shape[-1] z = self._spec(mix) - mag = self._magnitude(z) + mag = self._magnitude(z).to(mix.device) x = mag B, C, Fq, T = x.shape @@ -772,9 +772,21 @@ def forward(self, mix): x = x.view(B, S, -1, Fq, T) x = x * std[:, None] + mean[:, None] + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + zout = self._mask(z, x) x = self._ispec(zout, length) + # back to mps device + if x_is_mps: + x = x.to('mps') + if self.hybrid: xt = xt.view(B, S, -1, length) xt = xt * stdt[:, None] + meant[:, None] diff --git a/demucs/htdemucs.py b/demucs/htdemucs.py index 2de541c7..5d2eaaa1 100644 --- a/demucs/htdemucs.py +++ b/demucs/htdemucs.py @@ -536,7 +536,7 @@ def forward(self, mix): length_pre_pad = mix.shape[-1] mix = F.pad(mix, (0, training_length - length_pre_pad)) z = self._spec(mix) - mag = self._magnitude(z) + mag = self._magnitude(z).to(mix.device) x = mag B, C, Fq, T = x.shape @@ -625,6 +625,14 @@ def forward(self, mix): x = x.view(B, S, -1, Fq, T) x = x * std[:, None] + mean[:, None] + # to cpu as mps doesnt support complex numbers + # demucs issue #435 ##432 + # NOTE: in this case z already is on cpu + # TODO: remove this when mps supports complex numbers + x_is_mps = x.device.type == "mps" + if x_is_mps: + x = x.cpu() + zout = self._mask(z, x) if self.use_train_segment: if self.training: @@ -634,6 +642,10 @@ def forward(self, mix): else: x = self._ispec(zout, length) + # back to mps device + if x_is_mps: + x = x.to("mps") + if self.use_train_segment: if self.training: xt = xt.view(B, S, -1, length) diff --git a/demucs/spec.py b/demucs/spec.py index f4aa10f6..29250459 100644 --- a/demucs/spec.py +++ 
b/demucs/spec.py @@ -11,6 +11,9 @@ def spectro(x, n_fft=512, hop_length=None, pad=0): *other, length = x.shape x = x.reshape(-1, length) + is_mps = x.device.type == 'mps' + if is_mps: + x = x.cpu() z = th.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, @@ -29,6 +32,9 @@ def ispectro(z, hop_length=None, length=None, pad=0): n_fft = 2 * freqs - 2 z = z.view(-1, freqs, frames) win_length = n_fft // (1 + pad) + is_mps = z.device.type == 'mps' + if is_mps: + z = z.cpu() x = th.istft(z, n_fft, hop_length, diff --git a/docs/release.md b/docs/release.md index 2ec02bcb..fe33c1d5 100644 --- a/docs/release.md +++ b/docs/release.md @@ -11,6 +11,8 @@ Made diffq an optional dependency, with an error message if not installed. Added output format flac (Free Lossless Audio Codec) +Uses the CPU for complex-number operations when running on an MPS device (all other computations are performed on MPS). + Optimize codes to save memory ## V4.0.0, 7th of December 2022