Skip to content

Commit

Permalink
update sigma for inputs of shape (B, 2M) -> (*, 2M)
Browse files Browse the repository at this point in the history
  • Loading branch information
nducros committed Jun 6, 2023
1 parent e05a1e1 commit ef80388
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions spyrit/core/prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ def sigma(self, x: torch.tensor) -> torch.tensor:
:attr:`x`: batch of images in the Hadamard domain
Shape:
- Input: :math:`(B,2*M)` where :math:`B` is the batch dimension
- Output: :math:`(B, M)`
- Input: :math:`(*,2*M)` :math:`*` indicates one or more dimensions
- Output: :math:`(*, M)`
Example:
>>> x = torch.rand([10,2*400], dtype=torch.float)
Expand All @@ -312,7 +312,7 @@ def sigma(self, x: torch.tensor) -> torch.tensor:
torch.Size([10, 400])
"""
x = x[:,self.even_index] + x[:,self.odd_index]
x = x[...,self.even_index] + x[...,self.odd_index]
x = 4*x/(self.alpha**2) # Cov is in [-1,1] so *4
return x

Expand Down

0 comments on commit ef80388

Please sign in to comment.