diff --git a/src/mrpro/data/KData.py b/src/mrpro/data/KData.py index 47ddc27d2..719141921 100644 --- a/src/mrpro/data/KData.py +++ b/src/mrpro/data/KData.py @@ -332,6 +332,13 @@ def compress_coils( from mrpro.operators import PCACompressionOp coil_dim = -4 % self.data.ndim + + if n_compressed_coils > (n_current_coils := self.data.shape[coil_dim]): + raise ValueError( + f'Number of compressed coils ({n_compressed_coils}) cannot be greater ' + f'than the number of current coils ({n_current_coils}).' + ) + if batch_dims is not None and joint_dims is not Ellipsis: raise ValueError('Either batch_dims or joint_dims can be defined not both.') @@ -349,22 +356,21 @@ def compress_coils( # reshape to (*batch dimension, -1, coils) permute_order = ( - batch_dims_normalized - + [i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized] - + [coil_dim] + *batch_dims_normalized, + *[i for i in range(self.data.ndim) if i != coil_dim and i not in batch_dims_normalized], + coil_dim, ) - kdata_coil_compressed = self.data.permute(permute_order) - permuted_kdata_shape = kdata_coil_compressed.shape - kdata_coil_compressed = kdata_coil_compressed.flatten( + kdata_permuted = self.data.permute(permute_order) + kdata_flattened = kdata_permuted.flatten( start_dim=len(batch_dims_normalized), end_dim=-2 ) # keep separate dimensions and coil - pca_compression_op = PCACompressionOp(data=kdata_coil_compressed, n_components=n_compressed_coils) - (kdata_coil_compressed,) = pca_compression_op(kdata_coil_compressed) - + pca_compression_op = PCACompressionOp(data=kdata_flattened, n_components=n_compressed_coils) + (kdata_coil_compressed_flattened,) = pca_compression_op(kdata_flattened) + del kdata_flattened # reshape to original dimensions and undo permutation kdata_coil_compressed = torch.reshape( - kdata_coil_compressed, [*permuted_kdata_shape[:-1], n_compressed_coils] + kdata_coil_compressed_flattened, [*kdata_permuted.shape[:-1], n_compressed_coils] ).permute(*np.argsort(permute_order)) return type(self)(self.header.clone(), kdata_coil_compressed, self.traj.clone()) diff --git a/tests/data/test_kdata.py b/tests/data/test_kdata.py index 42cd061f6..158b7ede6 100644 --- a/tests/data/test_kdata.py +++ b/tests/data/test_kdata.py @@ -574,3 +574,10 @@ def test_KData_compress_coils_error_coil_dim(consistently_shaped_kdata): with pytest.raises(ValueError, match='Coil dimension must not'): consistently_shaped_kdata.compress_coils(n_compressed_coils=3, joint_dims=(-4,)) + + +def test_KData_compress_coils_error_n_coils(consistently_shaped_kdata): + """Test if error is raised if new coils would be larger than existing coils""" + existing_coils = consistently_shaped_kdata.data.shape[-4] + with pytest.raises(ValueError, match='greater'): + consistently_shaped_kdata.compress_coils(existing_coils + 1)