Skip to content

Commit fe57d20

Browse files
committed
[core] Fix flow field Jacobian determinant calculation
1 parent 7e188a6 commit fe57d20

File tree

1 file changed

+40
-17
lines changed

1 file changed

+40
-17
lines changed

src/deepali/core/flow.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def divergence(
146146
kwargs = dict(mode=mode, sigma=sigma, spacing=spacing, stride=stride)
147147
which = FlowDerivativeKeys.divergence(spatial_dims=D)
148148
deriv = flow_derivatives(flow, which=which, **kwargs)
149-
div = torch.zeros((N, 1) + flow.shape[2:], dtype=flow.dtype, device=flow.device)
149+
ref = deriv["du/dx"]
150+
div = torch.zeros((N, 1) + ref.shape[2:], dtype=ref.dtype, device=ref.device)
150151
for value in deriv.values():
151152
div = div.add_(value)
152153
return div
@@ -378,22 +379,44 @@ def jacobian_det(
378379
kwargs = dict(mode=mode, sigma=sigma, spacing=spacing, stride=stride)
379380
which = FlowDerivativeKeys.jacobian(spatial_dims=D)
380381
deriv = flow_derivatives(flow, which=which, **kwargs)
381-
jac: Optional[Tensor] = None
382-
for perm in permutations(range(D)):
383-
term: Optional[Tensor] = None
384-
for i, j in zip(range(D), perm):
385-
dij = deriv[FlowDerivativeKeys.symbol(i, j)]
386-
if i == j:
387-
dij = dij.add_(1) # T(x) = x + u(x)
388-
term = dij if term is None else term.mul_(dij)
389-
assert term is not None
390-
if jac is None:
391-
jac = term
392-
elif is_even_permutation(perm):
393-
jac = jac.add_(term)
394-
else:
395-
jac = jac.sub_(term)
396-
assert jac is not None
382+
# Add 1 to diagonal elements of Jacobian matrix, because T(x) = x + u(x)
383+
for i in range(D):
384+
deriv[FlowDerivativeKeys.symbol(i, i)].add_(1)
385+
if D == 2:
386+
a = deriv["du/dx"]
387+
b = deriv["du/dy"]
388+
c = deriv["dv/dx"]
389+
d = deriv["dv/dy"]
390+
jac = a.mul(d).sub_(b.mul(c))
391+
elif D == 3:
392+
a = deriv["du/dx"]
393+
b = deriv["du/dy"]
394+
c = deriv["du/dz"]
395+
d = deriv["dv/dx"]
396+
e = deriv["dv/dy"]
397+
f = deriv["dv/dz"]
398+
g = deriv["dw/dx"]
399+
h = deriv["dw/dy"]
400+
i = deriv["dw/dz"]
401+
term_1 = a.mul(e.mul(i).sub_(f.mul(h)))
402+
term_2 = b.mul(d.mul(i).sub_(g.mul(f)))
403+
term_3 = c.mul(d.mul(h).sub_(e.mul(g)))
404+
jac = term_1.sub(term_2).add(term_3)
405+
else:
406+
jac: Optional[Tensor] = None
407+
for perm in permutations(range(D)):
408+
term: Optional[Tensor] = None
409+
for i, j in zip(range(D), perm):
410+
dij = deriv[FlowDerivativeKeys.symbol(i, j)]
411+
term = dij if term is None else term.mul_(dij)
412+
assert term is not None
413+
if jac is None:
414+
jac = term
415+
elif is_even_permutation(perm):
416+
jac = jac.add_(term)
417+
else:
418+
jac = jac.sub_(term)
419+
assert jac is not None
397420
return jac
398421

399422

0 commit comments

Comments
 (0)