@@ -146,7 +146,8 @@ def divergence(
146
146
kwargs = dict (mode = mode , sigma = sigma , spacing = spacing , stride = stride )
147
147
which = FlowDerivativeKeys .divergence (spatial_dims = D )
148
148
deriv = flow_derivatives (flow , which = which , ** kwargs )
149
- div = torch .zeros ((N , 1 ) + flow .shape [2 :], dtype = flow .dtype , device = flow .device )
149
+ ref = deriv ["du/dx" ]
150
+ div = torch .zeros ((N , 1 ) + ref .shape [2 :], dtype = ref .dtype , device = ref .device )
150
151
for value in deriv .values ():
151
152
div = div .add_ (value )
152
153
return div
@@ -378,22 +379,44 @@ def jacobian_det(
378
379
kwargs = dict (mode = mode , sigma = sigma , spacing = spacing , stride = stride )
379
380
which = FlowDerivativeKeys .jacobian (spatial_dims = D )
380
381
deriv = flow_derivatives (flow , which = which , ** kwargs )
381
- jac : Optional [Tensor ] = None
382
- for perm in permutations (range (D )):
383
- term : Optional [Tensor ] = None
384
- for i , j in zip (range (D ), perm ):
385
- dij = deriv [FlowDerivativeKeys .symbol (i , j )]
386
- if i == j :
387
- dij = dij .add_ (1 ) # T(x) = x + u(x)
388
- term = dij if term is None else term .mul_ (dij )
389
- assert term is not None
390
- if jac is None :
391
- jac = term
392
- elif is_even_permutation (perm ):
393
- jac = jac .add_ (term )
394
- else :
395
- jac = jac .sub_ (term )
396
- assert jac is not None
382
+ # Add 1 to diagonal elements of Jacobian matrix, because T(x) = x + u(x)
383
+ for i in range (D ):
384
+ deriv [FlowDerivativeKeys .symbol (i , i )].add_ (1 )
385
+ if D == 2 :
386
+ a = deriv ["du/dx" ]
387
+ b = deriv ["du/dy" ]
388
+ c = deriv ["dv/dx" ]
389
+ d = deriv ["dv/dy" ]
390
+ jac = a .mul (d ).sub_ (b .mul (c ))
391
+ elif D == 3 :
392
+ a = deriv ["du/dx" ]
393
+ b = deriv ["du/dy" ]
394
+ c = deriv ["du/dz" ]
395
+ d = deriv ["dv/dx" ]
396
+ e = deriv ["dv/dy" ]
397
+ f = deriv ["dv/dz" ]
398
+ g = deriv ["dw/dx" ]
399
+ h = deriv ["dw/dy" ]
400
+ i = deriv ["dw/dz" ]
401
+ term_1 = a .mul (e .mul (i ).sub_ (f .mul (h )))
402
+ term_2 = b .mul (d .mul (i ).sub_ (g .mul (f )))
403
+ term_3 = c .mul (d .mul (h ).sub_ (e .mul (g )))
404
+ jac = term_1 .sub (term_2 ).add (term_3 )
405
+ else :
406
+ jac : Optional [Tensor ] = None
407
+ for perm in permutations (range (D )):
408
+ term : Optional [Tensor ] = None
409
+ for i , j in zip (range (D ), perm ):
410
+ dij = deriv [FlowDerivativeKeys .symbol (i , j )]
411
+ term = dij if term is None else term .mul_ (dij )
412
+ assert term is not None
413
+ if jac is None :
414
+ jac = term
415
+ elif is_even_permutation (perm ):
416
+ jac = jac .add_ (term )
417
+ else :
418
+ jac = jac .sub_ (term )
419
+ assert jac is not None
397
420
return jac
398
421
399
422
0 commit comments