Skip to content

Commit b94cc89

Browse files
committed
added function
1 parent a7d2276 commit b94cc89

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

utils/dice_score.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
#!/usr/bin/env python3
2+
3+
from rhscripts.metrics import dice_similarity
4+
import argparse
5+
import nibabel as nib
6+
7+
__scriptname__ = 'dice_similarity'
8+
__version__ = '0.0.1'
9+
10+
"""
11+
12+
VERSIONING
13+
0.0.1 # Added functionality
14+
15+
"""
16+
17+
18+
"""Calculate DICE score between two nifti files
19+
20+
Parameters
21+
----------
22+
...
23+
version : boolean, optional
24+
Print the version of the script
25+
"""
26+
27+
""" dice_similarity
28+
Author: Claes Ladefoged
29+
Date: 31-10-2024
30+
31+
### Example use cases from cmd-line ###
32+
33+
Get overall dice score, assuming binary input:
34+
python dice_similarity.py <nii_file_1> <nii_file_2>
35+
36+
Apply a threshold on the images first:
37+
python dice_similarity.py <nii_file_1> <nii_file_2> --threshold <threshold>
38+
39+
Use a mask to limit the threshold area:
40+
python dice_similarity.py <nii_file_1> <nii_file_2> --mask <nii_file_mask>
41+
42+
---------------------------------------------------------------------------------------
43+
44+
"""
45+
46+
# INPUTS
47+
parser = argparse.ArgumentParser()
48+
parser.add_argument("nii_file_1", help='Input nii file', type=str)
49+
parser.add_argument("nii_file_2", help='Input nii file', type=str)
50+
parser.add_argument("--threshold", help='Apply a threshold to the images', type=float)
51+
parser.add_argument("--mask", help='Input mask file to limit the area', type=str)
52+
args = parser.parse_args()
53+
54+
arr1 = nib.load(args.nii_file_1).get_fdata()
55+
arr2 = nib.load(args.nii_file_2).get_fdata()
56+
57+
if args.threshold:
58+
arr1 = (arr1 > args.threshold).astype(int)
59+
arr2 = (arr2 > args.threshold).astype(int)
60+
61+
if args.mask:
62+
mask = nib.load(args.mask).get_fdata()
63+
arr1[mask<1] = 0
64+
arr2[mask<1] = 0
65+
66+
dsc = dice_similarity(arr1, arr2)
67+
print(dsc)
68+

0 commit comments

Comments
 (0)