diff --git a/spyrit/core/prep.py b/spyrit/core/prep.py index 070a96fd..77d9904f 100644 --- a/spyrit/core/prep.py +++ b/spyrit/core/prep.py @@ -302,8 +302,8 @@ def sigma(self, x: torch.tensor) -> torch.tensor: :attr:`x`: batch of images in the Hadamard domain Shape: - - Input: :math:`(B,2*M)` where :math:`B` is the batch dimension - - Output: :math:`(B, M)` + - Input: :math:`(*,2*M)` :math:`*` indicates one or more dimensions + - Output: :math:`(*, M)` Example: >>> x = torch.rand([10,2*400], dtype=torch.float) @@ -312,7 +312,7 @@ def sigma(self, x: torch.tensor) -> torch.tensor: torch.Size([10, 400]) """ - x = x[:,self.even_index] + x[:,self.odd_index] + x = x[...,self.even_index] + x[...,self.odd_index] x = 4*x/(self.alpha**2) # Cov is in [-1,1] so *4 return x