Skip to content

Commit

Permalink
make align stack n-dimensional
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenBransen committed Nov 24, 2023
1 parent cdf9bf2 commit 4977475
Showing 1 changed file with 70 additions and 33 deletions.
103 changes: 70 additions & 33 deletions scm_confocal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from warnings import warn

def align_stack(images,startim=0,threshold=0,binning=1,smooth=0,upsample=1,
startoffset=(0,0),trim=True,blocksize=None,
show_process_im=False):
overlap_ratio=None,startoffset=None,cval=0,trim=True,
blocksize=None,show_process_im=False):
"""
Cross correlation alignment of image stack. Based around
skimage.feature.register_translation which enables sub-pixel precise
Expand All @@ -17,7 +17,7 @@ def align_stack(images,startim=0,threshold=0,binning=1,smooth=0,upsample=1,
Parameters
----------
images : 3d numpy array
images : Nd numpy array
the dataset which will be aligned along the first dimension (e.g. z)
startim : int
starting index that acts as reference for rest of stack
Expand All @@ -30,21 +30,22 @@ def align_stack(images,startim=0,threshold=0,binning=1,smooth=0,upsample=1,
translation
upsample : int
precision of translation in units of 1/pixel
startoffset : tuple of floats (y,x)
startoffset : tuple of floats
shift to apply to the starting image before alignment
Returns
-------
images : numpy.array
the image data with translation and (optional) trimming applied
shifts : list of (y,x) tuples
shifts : list of ([z],y,x) tuples
image shift values for each image in the dataset
"""
from skimage.registration import phase_cross_correlation
from scipy.ndimage import shift

n = len(images)
imshift = np.zeros((n,2))
ndims = len(images.shape)-1
imshift = np.zeros((n,ndims))

#check if stack is at least two images
if n == 1:
Expand All @@ -56,53 +57,81 @@ def align_stack(images,startim=0,threshold=0,binning=1,smooth=0,upsample=1,
print('startim out of range, starting at 0')
startim = 0

#apply offset to the first image if desired
if startoffset != (0,0):
images[startim] = shift(images[startim],startoffset,mode='constant',cval=0)
#make a copy of data for preprocessing
alignim = images.copy()

#make a copy of data for preprocessing only if needed to avoid clogging memory
if smooth == 0 and binning == 1 and not threshold > 0:
alignim = images
else:
alignim = images.copy()
#apply offset to the first image if desired
if startoffset != None:
images[startim] = shift(images[startim],startoffset,mode='constant',cval=cval)
imshift[startim] = startoffset

#bin, smooth and threshold data
if threshold > 0:
alignim[alignim < threshold] = 0
if binning!= 1:
alignim = bin_stack(alignim,n=(1,binning,binning),quiet=True,blocksize=blocksize)
alignim = bin_stack(alignim,n=[1]+[binning]*ndims,quiet=True,blocksize=blocksize)
if smooth != 0:
from skimage.filters import gaussian
alignim = [gaussian(im,smooth) for im in alignim]

if show_process_im:
import matplotlib.pyplot as plt
plt.figure()
plt.imshow(alignim[n//2])
if ndims==2:
import matplotlib.pyplot as plt
plt.figure()
plt.imshow(alignim[n//2])
if ndims==3 or ndims==4:
from stackscroller import stackscroller
global scroller
scroller = stackscroller(alignim[n//2])

#start going backwards from startim to first image
for i in reversed(range(0,startim)):
print('\raligning image {:3d} of {:3d}'.format(i,n-1),end='',flush=True)
dy,dx = phase_cross_correlation(alignim[i+1],alignim[i],upsample_factor=upsample)[0]
imshift[i] = imshift[i+1] + [binning*dy, binning*dx]
images[i] = shift(images[i],imshift[i],mode='constant',cval=0)
print('\raligning image {:>{w}} of {:>{w}}'.format(i,n-1,w=len(str(n-1))),end='')
shifts = phase_cross_correlation(
alignim[i+1],
alignim[i],
upsample_factor=upsample,
overlap_ratio=overlap_ratio,
normalization=None,
return_error=False
)
imshift[i] = imshift[i+1] + [binning*s for s in shifts]
images[i] = shift(
images[i],
imshift[i],
mode='constant',
cval=cval
)

#then continue from startim to end
for i in range(startim+1,n):
print('\raligning image {:3d} of {:3d}'.format(i,n-1),end='',flush=True)
dy,dx = phase_cross_correlation(alignim[i-1],alignim[i],upsample_factor=upsample)[0]
imshift[i] = imshift[i-1] + [binning*dy, binning*dx]
images[i] = shift(images[i],imshift[i],mode='constant',cval=0)
print('\raligning image {:>{w}} of {:>{w}}'.format(i,n-1,w=len(str(n-1))),end='')
shifts = phase_cross_correlation(
alignim[i-1],
alignim[i],
upsample_factor=upsample,
overlap_ratio=overlap_ratio,
normalization=None,
return_error=False
)
imshift[i] = imshift[i-1] + [binning*s for s in shifts]
images[i] = shift(
images[i],
imshift[i],
mode='constant',
cval=cval
)
print('')

imshift[startim] = startoffset

#trim down the edges of the dataset to only area which is always in view
if trim:
images = images[:,
int(max(max(imshift[:,0]),0)):int(min(min(imshift[:,0]),0))-1,
int(max(max(imshift[:,1]),0)):int(min(min(imshift[:,1]),0))-1
]
slices = [slice(None)] + [slice(
int(max(imshift[:,d])) if max(imshift[:,d])>0 else None,
int(min(imshift[:,d])) if min(imshift[:,d])<0 else None,
None
) for d in range(ndims)]
images = images[slices]


return (images,imshift)

Expand Down Expand Up @@ -1991,12 +2020,20 @@ def _on_lim_change(call):
if ny != shape[0] or nx != shape[1]:

barsize_px = barsize_px/shape[1]*nx

#if downsampling use bicubic for smoother result
if ny <= shape[0] and nx <= shape[1]:
resample_method = Image.Resampling.BICUBIC
#if upsampling preserve original pixel size appearence
else:
resample_method = Image.Resampling.NEAREST

if multichannel:
#convert to grayscale PIL.Image, resize, convert back to array
exportim = [
np.array(Image.fromarray(im).resize(
(int(nx),int(ny)),
resample=Image.Resampling.NEAREST
resample=resample_method
)) for im in exportim
]

Expand Down

0 comments on commit 4977475

Please sign in to comment.