Skip to content

Commit fa2693f

Browse files
authored
Implement mutual information image similarity (#322)
1 parent e501ba5 commit fa2693f

File tree

3 files changed

+91
-5
lines changed

3 files changed

+91
-5
lines changed

diffdrr/_modidx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@
5656
'diffdrr/metrics.py'),
5757
'diffdrr.metrics.MultiscaleNormalizedCrossCorrelation2d.forward': ( 'api/metrics.html#multiscalenormalizedcrosscorrelation2d.forward',
5858
'diffdrr/metrics.py'),
59+
'diffdrr.metrics.MutualInformation': ('api/metrics.html#mutualinformation', 'diffdrr/metrics.py'),
60+
'diffdrr.metrics.MutualInformation.__init__': ( 'api/metrics.html#mutualinformation.__init__',
61+
'diffdrr/metrics.py'),
62+
'diffdrr.metrics.MutualInformation.forward': ( 'api/metrics.html#mutualinformation.forward',
63+
'diffdrr/metrics.py'),
5964
'diffdrr.metrics.NormalizedCrossCorrelation2d': ( 'api/metrics.html#normalizedcrosscorrelation2d',
6065
'diffdrr/metrics.py'),
6166
'diffdrr.metrics.NormalizedCrossCorrelation2d.__init__': ( 'api/metrics.html#normalizedcrosscorrelation2d.__init__',

diffdrr/metrics.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
# %% auto 0
99
__all__ = ['NormalizedCrossCorrelation2d', 'MultiscaleNormalizedCrossCorrelation2d', 'GradientNormalizedCrossCorrelation2d',
10-
'LogGeodesicSE3', 'DoubleGeodesicSE3']
10+
'MutualInformation', 'LogGeodesicSE3', 'DoubleGeodesicSE3']
1111

1212
# %% ../notebooks/api/05_metrics.ipynb 6
1313
from einops import rearrange
@@ -100,7 +100,42 @@ def __init__(self, patch_size=None, sigma=1.0, **kwargs):
100100
def forward(self, x1, x2):
101101
return super().forward(self.sobel(x1), self.sobel(x2))
102102

103-
# %% ../notebooks/api/05_metrics.ipynb 14
103+
# %% ../notebooks/api/05_metrics.ipynb 11
104+
from kornia.enhance.histogram import marginal_pdf, joint_pdf
105+
106+
107+
class MutualInformation(torch.nn.Module):
108+
"""Mutual Information."""
109+
110+
def __init__(self, sigma=0.1, num_bins=256, epsilon=1e-10, normalize=True):
111+
super().__init__()
112+
self.register_buffer("sigma", torch.tensor(sigma))
113+
self.register_buffer("bins", torch.linspace(0.0, 1.0, num_bins))
114+
self.epsilon = epsilon
115+
self.normalize = normalize
116+
117+
def forward(self, x1, x2):
118+
assert x1.shape == x2.shape
119+
B, C, H, W = x1.shape
120+
121+
x1 = x1.view(B, H * W, C)
122+
x2 = x2.view(B, H * W, C)
123+
124+
pdf_x1, kernel_values1 = marginal_pdf(x1, self.bins, self.sigma, self.epsilon)
125+
pdf_x2, kernel_values2 = marginal_pdf(x2, self.bins, self.sigma, self.epsilon)
126+
pdf_x1x2 = joint_pdf(kernel_values1, kernel_values2)
127+
128+
H_x1 = -(pdf_x1 * (pdf_x1 + self.epsilon).log2()).sum(dim=1)
129+
H_x2 = -(pdf_x2 * (pdf_x2 + self.epsilon).log2()).sum(dim=1)
130+
H_x1x2 = -(pdf_x1x2 * (pdf_x1x2 + self.epsilon).log2()).sum(dim=(1, 2))
131+
132+
mutual_information = H_x1 + H_x2 - H_x1x2
133+
if self.normalize:
134+
mutual_information = 2 * mutual_information / (H_x1 + H_x2)
135+
136+
return mutual_information
137+
138+
# %% ../notebooks/api/05_metrics.ipynb 15
104139
from .pose import RigidTransform, convert
105140

106141

@@ -119,7 +154,7 @@ def forward(
119154
) -> Float[torch.Tensor, "b"]:
120155
return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1)
121156

122-
# %% ../notebooks/api/05_metrics.ipynb 17
157+
# %% ../notebooks/api/05_metrics.ipynb 18
123158
from .pose import so3_log_map
124159

125160

notebooks/api/05_metrics.ipynb

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,49 @@
203203
" return super().forward(self.sobel(x1), self.sobel(x2))"
204204
]
205205
},
206+
{
207+
"cell_type": "code",
208+
"execution_count": null,
209+
"id": "6e422fc1-8100-4120-b226-f6f54602fe3c",
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"#| export\n",
214+
"from kornia.enhance.histogram import marginal_pdf, joint_pdf\n",
215+
"\n",
216+
"\n",
217+
"class MutualInformation(torch.nn.Module):\n",
218+
" \"\"\"Mutual Information.\"\"\"\n",
219+
"\n",
220+
" def __init__(self, sigma=0.1, num_bins=256, epsilon=1e-10, normalize=True):\n",
221+
" super().__init__()\n",
222+
" self.register_buffer(\"sigma\", torch.tensor(sigma))\n",
223+
" self.register_buffer(\"bins\", torch.linspace(0.0, 1.0, num_bins))\n",
224+
" self.epsilon = epsilon\n",
225+
" self.normalize = normalize\n",
226+
"\n",
227+
" def forward(self, x1, x2):\n",
228+
" assert(x1.shape == x2.shape)\n",
229+
" B, C, H, W = x1.shape\n",
230+
"\n",
231+
" x1 = x1.view(B, H * W, C)\n",
232+
" x2 = x2.view(B, H * W, C)\n",
233+
"\n",
234+
" pdf_x1, kernel_values1 = marginal_pdf(x1, self.bins, self.sigma, self.epsilon)\n",
235+
" pdf_x2, kernel_values2 = marginal_pdf(x2, self.bins, self.sigma, self.epsilon)\n",
236+
" pdf_x1x2 = joint_pdf(kernel_values1, kernel_values2)\n",
237+
"\n",
238+
" H_x1 = -(pdf_x1 * (pdf_x1 + self.epsilon).log2()).sum(dim=1)\n",
239+
" H_x2 = -(pdf_x2 * (pdf_x2 + self.epsilon).log2()).sum(dim=1)\n",
240+
" H_x1x2 = -(pdf_x1x2 * (pdf_x1x2 + self.epsilon).log2()).sum(dim=(1, 2))\n",
241+
"\n",
242+
" mutual_information = H_x1 + H_x2 - H_x1x2\n",
243+
" if self.normalize:\n",
244+
" mutual_information = 2 * mutual_information / (H_x1 + H_x2)\n",
245+
"\n",
246+
" return mutual_information"
247+
]
248+
},
206249
{
207250
"cell_type": "code",
208251
"execution_count": null,
@@ -212,7 +255,7 @@
212255
{
213256
"data": {
214257
"text/plain": [
215-
"tensor([-0.0019, -0.0004, 0.0035, -0.0198, -0.0078, -0.0175, 0.0171, 0.0019])"
258+
"tensor([ 0.0002, 0.0006, 0.0003, 0.0005, 0.0003, 0.0005, 0.0012, -0.0001])"
216259
]
217260
},
218261
"execution_count": null,
@@ -243,7 +286,10 @@
243286
"gncc(x1, x2)\n",
244287
"\n",
245288
"gncc = GradientNormalizedCrossCorrelation2d(patch_size=9)\n",
246-
"gncc(x1, x2)"
289+
"gncc(x1, x2)\n",
290+
"\n",
291+
"mi = MutualInformation()\n",
292+
"mi(x1, x2)"
247293
]
248294
},
249295
{

0 commit comments

Comments
 (0)