Skip to content

Commit a87c0f3

Browse files
committed
Add native complex conversion
1 parent f0935d8 commit a87c0f3

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

asteroid_filterbanks/transforms.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,36 @@ def from_torchaudio(tensor, dim: int = -2):
324324
return torch.cat(torch.chunk(tensor, 2, dim=-1), dim=dim).squeeze(-1)
325325

326326

327+
def to_torch_complex(tensor, dim: int = -2):
328+
"""Converts complex-like torch tensor to native PyTorch complex tensor.
329+
330+
Args:
331+
tensor (torch.tensor): asteroid-style complex-like torch tensor.
332+
dim(int, optional): the frequency (or equivalent) dimension along which
333+
real and imaginary values are concatenated.
334+
335+
Returns:
336+
:class:`torch.Tensor`:
337+
Pytorch native complex-like torch tensor.
338+
"""
339+
return torch.view_as_complex(to_torchaudio(tensor, dim=dim))
340+
341+
342+
def from_torch_complex(tensor, dim: int = -2):
343+
"""Converts Pytorch native complex tensor to complex-like torch tensor.
344+
345+
Args:
346+
tensor (torch.tensor): PyTorch native complex-like torch tensor.
347+
dim(int, optional): the frequency (or equivalent) dimension along which
348+
real and imaginary values are concatenated.
349+
350+
Returns:
351+
:class:`torch.Tensor`:
352+
asteroid-style complex-like torch tensor.
353+
"""
354+
return torch.cat([tensor.real, tensor.imag], dim=dim)
355+
356+
327357
def angle(tensor, dim: int = -2):
328358
"""Return the angle of the complex-like torch tensor.
329359

tests/transforms_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,19 @@ def test_torchaudio_format(dim, max_tested_ndim):
190190
assert ta_tensor.shape[-1] == 2
191191

192192

193+
@pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2, -3])
194+
@pytest.mark.parametrize("max_tested_ndim", [4, 5])
195+
def test_torch_complex_format(dim, max_tested_ndim):
196+
# Random tensor shape
197+
tensor_shape = [random.randint(1, 10) for _ in range(max_tested_ndim)]
198+
# Make sure complex dimension has even shape
199+
tensor_shape[dim] = 2 * tensor_shape[dim]
200+
complex_tensor = torch.randn(tensor_shape)
201+
ta_tensor = transforms.to_torch_complex(complex_tensor, dim=dim)
202+
tensor_back = transforms.from_torch_complex(ta_tensor, dim=dim)
203+
assert_allclose(complex_tensor, tensor_back)
204+
205+
193206
def test_magphase():
194207
spec_shape = [2, 514, 100]
195208
spec = torch.randn(*spec_shape)

0 commit comments

Comments
 (0)