Skip to content

Commit 5238b3c

Browse files
authored
Merge pull request #714 from ANTsX/add_tests
BUG: label_image_centroids only worked with sequential labels 1..N
2 parents be10edc + ce8667d commit 5238b3c

File tree

6 files changed

+86
-48
lines changed

6 files changed

+86
-48
lines changed

ants/label/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
from .label_overlap_measures import label_overlap_measures
66
from .label_stats import label_stats
77
from .labels_to_matrix import labels_to_matrix
8-
from .multi_label_morphology import multi_label_morphology
8+
from .make_points_image import make_points_image
9+
from .multi_label_morphology import multi_label_morphology

ants/label/label_image_centroids.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ def label_image_centroids(image, physical=False, convex=True, verbose=False):
1818
---------
1919
image : ANTsImage
2020
image of integer labels
21-
21+
2222
physical : boolean
2323
whether you want physical space coordinates or not
24-
24+
2525
convex : boolean
2626
if True, return centroid
2727
if False return point with min average distance to other points with same label
28-
28+
2929
Returns
3030
-------
3131
dictionary w/ following key-value pairs:
@@ -58,14 +58,14 @@ def label_image_centroids(image, physical=False, convex=True, verbose=False):
5858
zc = np.zeros(n_labels)
5959

6060
if convex:
61-
for i in mylabels:
62-
idx = (labels == i).flatten()
63-
xc[i-1] = np.mean(xcoords[idx])
64-
yc[i-1] = np.mean(ycoords[idx])
65-
zc[i-1] = np.mean(zcoords[idx])
61+
for lab_idx, label_intensity in enumerate(mylabels):
62+
idx = (labels == label_intensity).flatten()
63+
xc[lab_idx] = np.mean(xcoords[idx])
64+
yc[lab_idx] = np.mean(ycoords[idx])
65+
zc[lab_idx] = np.mean(zcoords[idx])
6666
else:
67-
for i in mylabels:
68-
idx = (labels == i).flatten()
67+
for lab_idx, label_intensity in enumerate(mylabels):
68+
idx = (labels == label_intensity).flatten()
6969
xci = xcoords[idx]
7070
yci = ycoords[idx]
7171
zci = zcoords[idx]
@@ -75,9 +75,9 @@ def label_image_centroids(image, physical=False, convex=True, verbose=False):
7575
dist[j] = np.mean(np.sqrt((xci[j] - xci)**2 + (yci[j] - yci)**2 + (zci[j] - zci)**2))
7676

7777
mid = np.where(dist==np.min(dist))
78-
xc[i-1] = xci[mid]
79-
yc[i-1] = yci[mid]
80-
zc[i-1] = zci[mid]
78+
xc[lab_idx] = xci[mid]
79+
yc[lab_idx] = yci[mid]
80+
zc[lab_idx] = zci[mid]
8181

8282
centroids = np.vstack([xc,yc,zc]).T
8383

ants/label/make_points_image.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import ants
99

10-
def make_points_image(pts, mask, radius=5):
10+
def make_points_image(pts, target, radius=5):
1111
"""
1212
Create label image from physical space points
1313
@@ -21,10 +21,10 @@ def make_points_image(pts, mask, radius=5):
2121
Arguments
2222
---------
2323
pts : numpy.ndarray
24-
input powers points
24+
input points
2525
26-
mask : ANTsImage
27-
mask defining target space
26+
target : ANTsImage
27+
Image defining target space
2828
2929
radius : integer
3030
radius for the points
@@ -37,25 +37,25 @@ def make_points_image(pts, mask, radius=5):
3737
-------
3838
>>> import ants
3939
>>> import pandas as pd
40-
>>> mni = ants.image_read(ants.get_data('mni')).get_mask()
40+
>>> mni = ants.image_read(ants.get_data('mni'))
4141
>>> powers_pts = pd.read_csv(ants.get_data('powers_mni_itk'))
4242
>>> powers_labels = ants.make_points_image(powers_pts.iloc[:,:3].values, mni, radius=3)
4343
"""
44-
powers_lblimg = mask * 0
44+
lblimg = target * 0
4545
npts = len(pts)
46-
dim = mask.dimension
46+
dim = target.dimension
4747
if pts.shape[1] != dim:
4848
raise ValueError('points dimensionality should match that of images')
4949

5050
for r in range(npts):
5151
pt = pts[r,:]
52-
idx = ants.transform_physical_point_to_index(mask, pt.tolist() ).astype(int)
52+
idx = ants.transform_physical_point_to_index(target, pt.tolist() ).astype(int)
5353
in_image=True
54-
for kk in range(mask.dimension):
55-
in_image = in_image and idx[kk] >= 0 and idx[kk] < mask.shape[kk]
54+
for kk in range(target.dimension):
55+
in_image = in_image and idx[kk] >= 0 and idx[kk] < target.shape[kk]
5656
if ( in_image == True ):
5757
if (dim == 3):
58-
powers_lblimg[idx[0],idx[1],idx[2]] = r + 1
58+
lblimg[idx[0],idx[1],idx[2]] = r + 1
5959
elif (dim == 2):
60-
powers_lblimg[idx[0],idx[1]] = r + 1
61-
return ants.morphology( powers_lblimg, 'dilate', radius, 'grayscale' )
60+
lblimg[idx[0],idx[1]] = r + 1
61+
return ants.morphology( lblimg, 'dilate', radius, 'grayscale' )

ants/registration/affine_initializer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88

99

1010
def affine_initializer(fixed_image, moving_image, search_factor=20,
11-
radian_fraction=0.1, use_principal_axis=False,
11+
radian_fraction=0.1, use_principal_axis=False,
1212
local_search_iterations=10, mask=None, txfn=None ):
1313
"""
1414
A multi-start optimizer for affine registration
1515
Searches over the sphere to find a good initialization for further
1616
registration refinement, if needed. This is a wrapper for the ANTs
1717
function antsAffineInitializer.
18-
18+
1919
ANTsR function: `affineInitializer`
2020
2121
Arguments
2222
---------
2323
fixed_image : ANTsImage
2424
the fixed reference image
25-
moving_image : ANTsImage
25+
moving_image : ANTsImage
2626
the moving image to be mapped to the fixed space
2727
search_factor : scalar
2828
degree of increments on the sphere to search
@@ -41,7 +41,7 @@ def affine_initializer(fixed_image, moving_image, search_factor=20,
4141
-------
4242
ndarray
4343
transformation matrix
44-
44+
4545
Example
4646
-------
4747
>>> import ants
@@ -66,6 +66,5 @@ def affine_initializer(fixed_image, moving_image, search_factor=20,
6666

6767
if retval != 0:
6868
warnings.warn('ERROR: Non-zero exit status!')
69-
70-
return txfn
7169

70+
return txfn

ants/utils/mni2tal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def mni2tal(xin):
2525
2626
References
2727
----------
28-
http://bioimagesuite.yale.edu/mni2tal/501_95733_More\%20Accurate\%20Talairach\%20Coordinates\%20SLIDES.pdf
28+
http://bioimagesuite.yale.edu/mni2tal/501_95733_More\\%20Accurate\\%20Talairach\\%20Coordinates\\%20SLIDES.pdf
2929
http://imaging.mrc-cbu.cam.ac.uk/imaging/MniTalairach
3030
"""
3131
if (not isinstance(xin, (tuple,list))) or (len(xin) != 3):

tests/test_utils.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,13 @@ def test_crop_image_example(self):
209209

210210
# label image not float
211211
cropped = ants.crop_image(fi, fi.clone("unsigned int"), 100)
212-
212+
213213
# channel image
214214
fi = ants.image_read( ants.get_ants_data('r16') )
215215
cropped = ants.crop_image(fi)
216216
fi2 = ants.merge_channels([fi,fi])
217217
cropped2 = ants.crop_image(fi2)
218-
218+
219219
self.assertEqual(cropped.shape, cropped2.shape)
220220

221221
def test_crop_indices_example(self):
@@ -583,11 +583,27 @@ def setUp(self):
583583
def tearDown(self):
584584
pass
585585

586-
def test_label_clusters_example(self):
586+
def test_label_image_centroids(self):
587587
image = ants.from_numpy(
588588
np.asarray([[[0, 2], [1, 3]], [[4, 6], [5, 7]]]).astype("float32")
589589
)
590590
labels = ants.label_image_centroids(image)
591+
self.assertEqual(len(labels['labels']), 7)
592+
593+
# Test non-sequential labels
594+
image = ants.from_numpy(
595+
np.asarray([[[0, 2], [2, 2]], [[2, 0], [5, 0]]]).astype("float32")
596+
)
597+
598+
labels = ants.label_image_centroids(image)
599+
self.assertTrue(len(labels['labels']) == 2)
600+
self.assertTrue(labels['labels'][1] == 5)
601+
self.assertTrue(np.allclose(labels['vertices'][0], [0.5 , 0.5 , 0.25], atol=1e-5))
602+
# With convex = False, the centroid position should change
603+
labels = ants.label_image_centroids(image, convex=False)
604+
self.assertTrue(np.allclose(labels['vertices'][0], [1.0, 1.0, 0.0], atol=1e-5))
605+
# single point unchanged
606+
self.assertTrue(np.allclose(labels['vertices'][1], [0.0, 1.0, 1.0], atol=1e-5))
591607

592608

593609
class TestModule_label_overlap_measures(unittest.TestCase):
@@ -633,6 +649,28 @@ def test_labels_to_matrix_example(self):
633649
labmat = ants.labels_to_matrix(labs, mask)
634650

635651

652+
class TestModule_make_points_image(unittest.TestCase):
653+
def setUp(self):
654+
pass
655+
656+
def tearDown(self):
657+
pass
658+
659+
def test_make_points_image_example(self):
660+
image = ants.image_read(ants.get_ants_data("r16"))
661+
points = np.array([[102, 76],[134, 129]])
662+
points_image = ants.make_points_image(points, image, radius=5)
663+
stats = ants.label_stats(image, points_image)
664+
self.assertTrue(np.allclose(stats['Volume'].to_numpy()[1:3], 97.0, atol=1e-5))
665+
self.assertTrue(np.allclose(stats['x'].to_numpy()[1:3], points[:,0], atol=1e-5))
666+
self.assertTrue(np.allclose(stats['y'].to_numpy()[1:3], points[:,1], atol=1e-5))
667+
668+
points = np.array([[102, 76, 50],[134, 129, 50]])
669+
# Shouldn't allow 3D points on a 2D image
670+
with self.assertRaises(Exception):
671+
points_image = ants.make_points_image(image, points, radius=3)
672+
673+
636674
class TestModule_mask_image(unittest.TestCase):
637675
def setUp(self):
638676
pass
@@ -846,7 +884,7 @@ def setUp(self):
846884
pass
847885
def tearDown(self):
848886
pass
849-
887+
850888
def test_bspline_field(self):
851889
points = np.array([[-50, -50]])
852890
deltas = np.array([[10, 10]])
@@ -874,29 +912,29 @@ def test_ilr(self):
874912
result = ants.ilr( df, vlist, myform)
875913
myform = " mat2 ~ covar + mat1 "
876914
result = ants.ilr( df, vlist, myform)
877-
915+
878916
def test_quantile(self):
879917
img = ants.image_read(ants.get_data('r16'))
880918
ants.quantile(img, 0.5)
881919
ants.quantile(img, (0.5, 0.75))
882-
920+
883921
def test_bandpass(self):
884922
brainSignal = np.random.randn( 400, 1000 )
885923
tr = 1
886924
filtered = ants.bandpass_filter_matrix( brainSignal, tr = tr )
887-
925+
888926
def test_compcorr(self):
889927
cc = ants.compcor( ants.image_read(ants.get_ants_data("ch2")) )
890928

891929
def test_histogram_match(self):
892930
src_img = ants.image_read(ants.get_data('r16'))
893931
ref_img = ants.image_read(ants.get_data('r64'))
894932
src_ref = ants.histogram_match_image(src_img, ref_img)
895-
933+
896934
src_img = ants.image_read(ants.get_data('r16'))
897935
ref_img = ants.image_read(ants.get_data('r64'))
898936
src_ref = ants.histogram_match_image2(src_img, ref_img)
899-
937+
900938
def test_averaging(self):
901939
x0=[ ants.get_data('r16'), ants.get_data('r27'), ants.get_data('r62'), ants.get_data('r64') ]
902940
x1=[]
@@ -906,7 +944,7 @@ def test_averaging(self):
906944
avg1=ants.average_images(x1)
907945
avg2=ants.average_images(x1,mask=0)
908946
avg3=ants.average_images(x1,mask=1,normalize=True)
909-
947+
910948
def test_n3_2(self):
911949
image = ants.image_read( ants.get_ants_data('r16') )
912950
image_n3 = ants.n3_bias_field_correction2(image)
@@ -929,29 +967,29 @@ def test_thin_plate_spline(self):
929967
displacement_origins=points, displacements=deltas,
930968
origin=[0.0, 0.0], spacing=[1.0, 1.0], size=[100, 100],
931969
direction=np.array([[-1, 0], [0, -1]]))
932-
970+
933971
def test_multi_label_morph(self):
934972
img = ants.image_read(ants.get_data('r16'))
935973
labels = ants.get_mask(img,1,150) + ants.get_mask(img,151,225) * 2
936974
labels_dilated = ants.multi_label_morphology(labels, 'MD', 2)
937975
# should see original label regions preserved in dilated version
938976
# label N should have mean N and 0 variance
939977
print(ants.label_stats(labels_dilated, labels))
940-
978+
941979
def test_hausdorff_distance(self):
942980
r16 = ants.image_read( ants.get_ants_data('r16') )
943981
r64 = ants.image_read( ants.get_ants_data('r64') )
944982
s16 = ants.kmeans_segmentation( r16, 3 )['segmentation']
945983
s64 = ants.kmeans_segmentation( r64, 3 )['segmentation']
946984
stats = ants.hausdorff_distance(s16, s64)
947-
985+
948986
def test_channels_first(self):
949987
import ants
950988
image = ants.image_read(ants.get_ants_data('r16'))
951989
image2 = ants.image_read(ants.get_ants_data('r16'))
952990
img3 = ants.merge_channels([image,image2])
953991
img4 = ants.merge_channels([image,image2], channels_first=True)
954-
992+
955993
self.assertTrue(np.allclose(img3.numpy()[:,:,0], img4.numpy()[0,:,:]))
956994
self.assertTrue(np.allclose(img3.numpy()[:,:,1], img4.numpy()[1,:,:]))
957995

0 commit comments

Comments
 (0)