Skip to content

Commit aa7050f

Browse files
Calculate transfer functions on grids smaller than data (#180)
* stretched_multiply functions & tests * first pass `apply_transfer_function_filter` * check filter inputs * improved comment * test `apply_transfer_function_filter` * pad and multiply instead of iterating over stretched multiply blocks * phase reconstruction uses apply_transfer_function_filter * revised strategy --- prepad, require stretch-multiply to be divisible, then crop * fix tests --- now need divisibility * refactor inverse filter * refactor fluorescence deconvolution to support mismatched PSF and data * refactor to support transfer function banks * refactor phase and fluorescence to use "filter bank" functions * remove transverse_downsample_factor and interpolation * styling * clearer variable names * return real part only and update docs * expect real-valued input arrays * add `isotropic_thin_3d` TODO * update tests * put output on same device as input * comment Ziwen's optimization suggestions
1 parent 1f84072 commit aa7050f

File tree

7 files changed

+319
-49
lines changed

7 files changed

+319
-49
lines changed

tests/test_filter.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
import torch
3+
4+
from waveorder import filter
5+
6+
7+
def test_apply_transfer_function_filter():
8+
input_array = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
9+
transfer_function_bank = torch.tensor([[[[1, 0], [0, 0]]]])
10+
result = filter.apply_filter_bank(transfer_function_bank, input_array)
11+
expected = torch.tensor([[[10, 10], [10, 10]]]) / 4
12+
assert torch.allclose(result, expected)
13+
14+
# Test with incompatible shapes
15+
input_array = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
16+
transfer_function_bank = torch.tensor(
17+
[[[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]]]
18+
)
19+
with pytest.raises(ValueError):
20+
filter.apply_filter_bank(transfer_function_bank, input_array)
21+
22+
23+
def test_stretched_multiply():
24+
small_array = torch.tensor([[1, 2], [3, 4]])
25+
large_array = torch.tensor(
26+
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
27+
)
28+
result = filter.stretched_multiply(small_array, large_array)
29+
expected = torch.tensor(
30+
[[1, 2, 6, 8], [5, 6, 14, 16], [27, 30, 44, 48], [39, 42, 60, 64]]
31+
)
32+
assert torch.all(result == expected)
33+
assert torch.all(
34+
filter.stretched_multiply(large_array, large_array) == large_array**2
35+
)
36+
37+
# Test that output dims are correct
38+
rand_array_3x3x3 = torch.rand((3, 3, 3))
39+
rand_array_99x99x99 = torch.rand((99, 99, 99))
40+
result = filter.stretched_multiply(rand_array_3x3x3, rand_array_99x99x99)
41+
assert result.shape == (99, 99, 99)
42+
43+
44+
def test_stretched_multiply_incompatible_dims():
45+
# small_array > large_array
46+
small_array = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
47+
large_array = torch.tensor([[1, 2], [3, 4]])
48+
with pytest.raises(ValueError):
49+
filter.stretched_multiply(small_array, large_array)
50+
51+
# Mismatched dims
52+
small_array = torch.tensor([[1, 2], [3, 4]])
53+
large_array = torch.tensor(
54+
[[[1, 2], [4, 5], [7, 8]], [[10, 11], [13, 14], [16, 17]]]
55+
)
56+
with pytest.raises(ValueError):
57+
filter.stretched_multiply(small_array, large_array)

waveorder/filter.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import itertools
2+
3+
import torch
4+
5+
6+
def apply_filter_bank(
7+
io_filter_bank: torch.Tensor,
8+
i_input_array: torch.Tensor,
9+
) -> torch.Tensor:
10+
"""
11+
Applies a filter bank to an input array.
12+
13+
io_filter_bank.shape must be smaller or equal to i_input_array.shape in all
14+
dimensions. When io_filter_bank is smaller, it is effectively "stretched"
15+
to apply the filter.
16+
17+
io_filter_bank is in "wrapped" format, i.e., the zero frequency is the
18+
zeroth element.
19+
20+
i_input_array and io_filter_bank must have inverse sample spacing, i.e.,
21+
is input_array contains samples spaced by dx, then io_filter_bank must
22+
have extent 1/dx. Note that there is no need for io_filter_bank to have
23+
sample spacing 1/(n*dx) because io_filter_bank will be stretched.
24+
25+
Parameters
26+
----------
27+
io_filter_bank : torch.Tensor
28+
The filter bank to be applied in the frequency domain.
29+
The spatial extent of io_filter_bank must be 1/dx, where dx is the
30+
sample spacing of i_input_array.
31+
32+
Leading dimensions are the input and output dimensions.
33+
io_filter_bank.shape[:2] == (num_input_channels, num_output_channels)
34+
35+
Trailing dimensions are spatial frequency dimensions.
36+
io_filter_bank.shape[2:] == (Z', Y', X') or (Y', X')
37+
38+
i_input_array : torch.Tensor
39+
The real-valued input array with sample spacing dx to be filtered.
40+
41+
Leading dimension is the input dimension, matching the filter bank.
42+
i_input_array.shape[0] == i
43+
44+
Trailing dimensions are spatial dimensions.
45+
i_input_array.shape[1:] == (Z, Y, X) or (Y, X)
46+
47+
Returns
48+
-------
49+
torch.Tensor
50+
The filtered real-valued output array with shape
51+
(num_output_channels, Z, Y, X) or (num_output_channels, Y, X).
52+
53+
"""
54+
55+
# Ensure all dimensions of transfer_function are smaller than or equal to input_array
56+
if any(
57+
t > i
58+
for t, i in zip(io_filter_bank.shape[2:], i_input_array.shape[1:])
59+
):
60+
raise ValueError(
61+
"All spatial dimensions of io_filter_bank must be <= i_input_array."
62+
)
63+
64+
# Ensure the number of spatial dimensions match
65+
if io_filter_bank.ndim - i_input_array.ndim != 1:
66+
raise ValueError(
67+
"io_filter_bank and i_input_array must have the same number of spatial dimensions."
68+
)
69+
70+
# Ensure the input dimensions match
71+
if io_filter_bank.shape[0] != i_input_array.shape[0]:
72+
raise ValueError(
73+
"io_filter_bank.shape[0] and i_input_array.shape[0] must be the same."
74+
)
75+
76+
num_input_channels, num_output_channels = io_filter_bank.shape[:2]
77+
spatial_dims = io_filter_bank.shape[2:]
78+
79+
# Pad input_array until each dimension is divisible by transfer_function
80+
pad_sizes = [
81+
(0, (t - (i % t)) % t)
82+
for t, i in zip(
83+
io_filter_bank.shape[2:][::-1], i_input_array.shape[1:][::-1]
84+
)
85+
]
86+
flat_pad_sizes = list(itertools.chain(*pad_sizes))
87+
padded_input_array = torch.nn.functional.pad(i_input_array, flat_pad_sizes)
88+
89+
# Apply the transfer function in the frequency domain
90+
fft_dims = [d for d in range(1, i_input_array.ndim)]
91+
padded_input_spectrum = torch.fft.fftn(padded_input_array, dim=fft_dims)
92+
93+
# Matrix-vector multiplication over f
94+
# If this is a bottleneck, consider extending `stretched_multiply` to
95+
# a `stretched_matrix_multiply` that uses an call like
96+
# torch.einsum('io..., i... -> o...', io_filter_bank, padded_input_spectrum)
97+
#
98+
# Further optimization is likely with a combination of
99+
# torch.baddbmm, torch.pixel_shuffle, torch.pixel_unshuffle.
100+
padded_output_spectrum = torch.zeros(
101+
(num_output_channels,) + spatial_dims,
102+
dtype=padded_input_spectrum.dtype,
103+
device=padded_input_spectrum.device,
104+
)
105+
for input_channel_idx in range(num_input_channels):
106+
for output_channel_idx in range(num_output_channels):
107+
padded_output_spectrum[output_channel_idx] += stretched_multiply(
108+
io_filter_bank[input_channel_idx, output_channel_idx],
109+
padded_input_spectrum[input_channel_idx],
110+
)
111+
112+
# Cast to real, ignoring imaginary part
113+
padded_result = torch.real(
114+
torch.fft.ifftn(padded_output_spectrum, dim=fft_dims)
115+
)
116+
117+
# Remove padding and return
118+
slices = tuple(slice(0, i) for i in i_input_array.shape)
119+
return padded_result[slices]
120+
121+
122+
def stretched_multiply(
123+
small_array: torch.Tensor, large_array: torch.Tensor
124+
) -> torch.Tensor:
125+
"""
126+
Effectively "stretches" small_array onto large_array before multiplying.
127+
128+
Each dimension of large_array must be divisible by each dimension of small_array.
129+
130+
Instead of upsampling small_array, this function uses a "block element-wise"
131+
multiplication by breaking the large_array into blocks before element-wise
132+
multiplication with the small_array.
133+
134+
For example, a `stretched_multiply` of a 3x3 array by a 99x99 array will
135+
divide the 99x99 array into 33x33 blocks
136+
[[33x33, 33x33, 33x33],
137+
[33x33, 33x33, 33x33],
138+
[33x33, 33x33, 33x33]]
139+
and multiply each block by the corresponding element in the 3x3 array.
140+
141+
Returns an array with the same shape as large_array.
142+
143+
Works for arbitrary dimensions.
144+
145+
Parameters
146+
----------
147+
small_array : torch.Tensor
148+
A smaller array whose elements will be "stretched" onto blocks in the large array.
149+
large_array : torch.Tensor
150+
A larger array that will be divided into blocks and multiplied by the small array.
151+
152+
Returns
153+
-------
154+
torch.Tensor
155+
Resulting tensor with shape matching large_array.
156+
157+
Example
158+
-------
159+
small_array = torch.tensor([[1, 2],
160+
[3, 4]])
161+
162+
large_array = torch.tensor([[1, 2, 3, 4],
163+
[5, 6, 7, 8],
164+
[9, 10, 11, 12],
165+
[13, 14, 15, 16]])
166+
167+
stretched_multiply(small_array, large_array) returns
168+
169+
[[ 1, 2, 6, 8],
170+
[ 5, 6, 14, 16],
171+
[ 27, 30, 44, 48],
172+
[ 39, 42, 60, 64]]
173+
"""
174+
175+
# Ensure each dimension of large_array is divisible by each dimension of small_array
176+
if any(l % s != 0 for s, l in zip(small_array.shape, large_array.shape)):
177+
raise ValueError(
178+
"Each dimension of large_array must be divisible by each dimension of small_array"
179+
)
180+
181+
# Ensure the number of dimensions match
182+
if small_array.ndim != large_array.ndim:
183+
raise ValueError(
184+
"small_array and large_array must have the same number of dimensions"
185+
)
186+
187+
# Get shapes
188+
s_shape = small_array.shape
189+
l_shape = large_array.shape
190+
191+
# Reshape both array into blocks
192+
block_shape = tuple(p // s for p, s in zip(l_shape, s_shape))
193+
new_large_shape = tuple(itertools.chain(*zip(s_shape, block_shape)))
194+
new_small_shape = tuple(
195+
itertools.chain(*zip(s_shape, small_array.ndim * (1,)))
196+
)
197+
reshaped_large_array = large_array.reshape(new_large_shape)
198+
reshaped_small_array = small_array.reshape(new_small_shape)
199+
200+
# Multiply the reshaped arrays
201+
reshaped_result = reshaped_large_array * reshaped_small_array
202+
203+
# Reshape the result back to the large array shape
204+
result = reshaped_result.reshape(l_shape)
205+
206+
return result

waveorder/models/inplane_oriented_thick_pol3d_vector.py

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import numpy as np
44
import torch
55
from torch import Tensor
6-
from torch.nn.functional import avg_pool3d, interpolate
6+
from torch.nn.functional import avg_pool3d
77

88
from waveorder import optics, sampling, stokes, util
9+
from waveorder.filter import apply_filter_bank
910
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
1011

1112

@@ -40,7 +41,6 @@ def calculate_transfer_function(
4041
numerical_aperture_detection: float,
4142
invert_phase_contrast: bool = False,
4243
fourier_oversample_factor: int = 1,
43-
transverse_downsample_factor: int = 1,
4444
) -> tuple[
4545
torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]
4646
]:
@@ -66,22 +66,8 @@ def calculate_transfer_function(
6666

6767
tf_calculation_shape = (
6868
zyx_shape[0] * z_factor * fourier_oversample_factor,
69-
int(
70-
np.ceil(
71-
zyx_shape[1]
72-
* yx_factor
73-
* fourier_oversample_factor
74-
/ transverse_downsample_factor
75-
)
76-
),
77-
int(
78-
np.ceil(
79-
zyx_shape[2]
80-
* yx_factor
81-
* fourier_oversample_factor
82-
/ transverse_downsample_factor
83-
)
84-
),
69+
int(np.ceil(zyx_shape[1] * yx_factor * fourier_oversample_factor)),
70+
int(np.ceil(zyx_shape[2] * yx_factor * fourier_oversample_factor)),
8571
)
8672

8773
(
@@ -125,25 +111,12 @@ def calculate_transfer_function(
125111
)
126112

127113
# Compute singular system on cropped and downsampled
128-
U, S, Vh = calculate_singular_system(cropped)
129-
130-
# Interpolate to final size in YX
131-
def complex_interpolate(
132-
tensor: torch.Tensor, zyx_shape: tuple[int, int, int]
133-
) -> torch.Tensor:
134-
interpolated_real = interpolate(tensor.real, size=zyx_shape)
135-
interpolated_imag = interpolate(tensor.imag, size=zyx_shape)
136-
return interpolated_real + 1j * interpolated_imag
137-
138-
full_cropped = complex_interpolate(cropped, zyx_shape)
139-
full_U = complex_interpolate(U, zyx_shape)
140-
full_S = interpolate(S[None], size=zyx_shape)[0] # S is real
141-
full_Vh = complex_interpolate(Vh, zyx_shape)
114+
singular_system = calculate_singular_system(cropped)
142115

143116
return (
144-
full_cropped,
117+
cropped,
145118
intensity_to_stokes_matrix,
146-
(full_U, full_S, full_Vh),
119+
singular_system,
147120
)
148121

149122

@@ -334,20 +307,14 @@ def apply_inverse_transfer_function(
334307
TV_rho_strength: float = 1e-3,
335308
TV_iterations: int = 10,
336309
):
337-
sZYX_data = torch.fft.fftn(szyx_data, dim=(1, 2, 3))
338-
339310
# Key computation
340311
print("Computing inverse filter")
341312
U, S, Vh = singular_system
342313
S_reg = S / (S**2 + regularization_strength)
343-
344-
ZYXsf_inverse_filter = torch.einsum(
314+
sfzyx_inverse_filter = torch.einsum(
345315
"sjzyx,jzyx,jfzyx->sfzyx", U, S_reg, Vh
346316
)
347317

348-
# Apply inverse filter
349-
fZYX_reconstructed = torch.einsum(
350-
"szyx,sfzyx->fzyx", sZYX_data, ZYXsf_inverse_filter
351-
)
318+
fzyx_recon = apply_filter_bank(sfzyx_inverse_filter, szyx_data)
352319

353-
return torch.real(torch.fft.ifftn(fZYX_reconstructed, dim=(1, 2, 3)))
320+
return fzyx_recon

waveorder/models/isotropic_fluorescent_thick_3d.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from torch import Tensor
66

77
from waveorder import optics, sampling, util
8+
from waveorder.filter import apply_filter_bank
9+
from waveorder.reconstruct import tikhonov_regularized_inverse_filter
810
from waveorder.visuals.napari_visuals import add_transfer_function_to_viewer
911

1012

@@ -211,12 +213,15 @@ def apply_inverse_transfer_function(
211213

212214
# Reconstruct
213215
if reconstruction_algorithm == "Tikhonov":
214-
f_real = util.single_variable_tikhonov_deconvolution_3D(
215-
zyx_padded,
216-
optical_transfer_function,
217-
reg_re=regularization_strength,
216+
inverse_filter = tikhonov_regularized_inverse_filter(
217+
optical_transfer_function, regularization_strength
218218
)
219219

220+
# [None]s and [0] are for applying a 1x1 "bank" of filters.
221+
# For further uniformity, consider returning (1, Z, Y, X)
222+
f_real = apply_filter_bank(
223+
inverse_filter[None, None], zyx_padded[None]
224+
)[0]
220225
elif reconstruction_algorithm == "TV":
221226
raise NotImplementedError
222227
f_real = util.single_variable_admm_tv_deconvolution_3D(

waveorder/models/isotropic_thin_3d.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def apply_inverse_transfer_function(
288288
zyx_data_hat = torch.fft.fft2(zyx_data_normalized, dim=(1, 2))
289289

290290
# TODO AHA and b_vec calculations should be moved into tikhonov/tv calculations
291+
# TODO Reformulate to use filter.apply_filter_bank
291292
AHA = [
292293
torch.sum(torch.abs(absorption_2d_to_3d_transfer_function) ** 2, dim=0)
293294
+ regularization_strength,

0 commit comments

Comments
 (0)