
Commit

Merge pull request #126 from uncbiag/new_tv_norm
New tv norm
marcniethammer authored Jul 9, 2018
2 parents c6d4662 + 876d226 commit 6d81ccc
Showing 6 changed files with 122 additions and 56 deletions.
1 change: 1 addition & 0 deletions experiments/multi_stage_smoother_learning.py
@@ -238,6 +238,7 @@ def load_image_pair_configuration( filename_pt, input_image_directory ):
if args.seed is not None:
print('Setting the random seed to {:}'.format(args.seed))
random.seed(args.seed)
torch.manual_seed(args.seed)

print('Loading settings from file: ' + args.config)
params = pars.ParameterDict()
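Note that the hunk above seeds Python's random module and torch, but not numpy or the CUDA generators. A minimal sketch of a more complete seeding helper (the name seed_all and the numpy/CUDA calls are my addition, not part of this commit):

import random
import numpy as np
import torch

def seed_all(seed):
    # seed every RNG the experiments may touch (hypothetical helper)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)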
6 changes: 3 additions & 3 deletions experiments/test2d_025.json
@@ -18,9 +18,9 @@
"deep_network_local_weight_smoothing": 0.02,
"diffusion_weight_penalty": 0.0,
"edge_penalty_filename": "DEBUG_edge_penalty.nrrd",
"edge_penalty_gamma": 15.0,
"edge_penalty_terminate_after_writing": true,
"edge_penalty_write_to_file": true,
"edge_penalty_gamma": 10.0,
"edge_penalty_terminate_after_writing": false,
"edge_penalty_write_to_file": false,
"estimate_around_global_weights": true,
"kernel_sizes": [
7,
122 changes: 71 additions & 51 deletions pyreg/deep_smoothers.py
@@ -164,7 +164,7 @@ def half_sigmoid(x,alpha=1):
def _compute_localized_edge_penalty(I,spacing,gamma):
# needs to be batch B x X x Y x Z format
fdt = fd.FD_torch(spacing=spacing)
gnI = float(np.min(spacing)) * fdt.grad_norm_sqr(I) ** 0.5
gnI = float(np.min(spacing)) * fdt.grad_norm_sqr_f(I) ** 0.5

# compute edge penalty
localized_edge_penalty = 1.0 / (1.0 + gamma * gnI) # this is what we weight the OMT values with
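For intuition: the localized edge penalty computed above is the weight 1/(1 + gamma*||grad I||), with the gradient norm scaled by the smallest voxel spacing, and it is what the OMT and TV terms get multiplied with. As a small worked example (the gradient value is illustrative), for a scaled gradient norm of 0.5 the accompanying change of edge_penalty_gamma from 15 to 10 in test2d_025.json raises this weight from 1/(1 + 7.5) ≈ 0.12 to 1/(1 + 5) ≈ 0.17, i.e. strong edges are down-weighted somewhat less aggressively.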
@@ -217,7 +217,7 @@ def _compute_weighted_total_variation_1d(d_in,w, spacing, bc_val, pnorm=2):
# need to use torch.abs here to make sure the proper subgradient is computed at zero
batch_size = d.size()[0]
volumeElement = spacing.prod()
t0 = torch.abs(fdt.dXc(d))
t0 = torch.abs(fdt.dXf(d))

tm = t0*w

@@ -238,7 +238,7 @@ def _compute_weighted_total_variation_2d(d_in,w, spacing, bc_val, pnorm=2):
# need to use torch.norm here to make sure the proper subgradient is computed at zero
batch_size = d.size()[0]
volumeElement = spacing.prod()
t0 = torch.norm(torch.stack((fdt.dXc(d),fdt.dYc(d))),pnorm,0)
t0 = torch.norm(torch.stack((fdt.dXf(d),fdt.dYf(d))),pnorm,0)

tm = t0*w

@@ -262,9 +262,9 @@ def _compute_weighted_total_variation_3d(d_in,w, spacing, bc_val, pnorm=2):
batch_size = d.size()[0]
volumeElement = spacing.prod()

t0 = torch.norm(torch.stack((fdt.dXc(d),
fdt.dYc(d),
fdt.dZc(d))), pnorm, 0)
t0 = torch.norm(torch.stack((fdt.dXf(d),
fdt.dYf(d),
fdt.dZf(d))), pnorm, 0)

tm = t0*w

@@ -289,7 +289,7 @@ def _compute_local_norm_of_gradient_1d(d,spacing,pnorm=2):

fdt = fd.FD_torch(spacing=spacing)
# need to use torch.abs here to make sure the proper subgradient is computed at zero
t0 = torch.abs(fdt.dXc(d))
t0 = torch.abs(fdt.dXf(d))

return t0

@@ -299,8 +299,8 @@ def _compute_local_norm_of_gradient_2d(d,spacing,pnorm=2):
# need to use torch.norm here to make sure the proper subgradient is computed at zero
#t0 = torch.norm(torch.stack((fdt.dXc(d),fdt.dYc(d))),pnorm,0)

dX = fdt.dXc(d)
dY = fdt.dYc(d)
dX = fdt.dXf(d)
dY = fdt.dYf(d)

t0 = torch.norm(torch.stack((dX, dY)), pnorm, 0)

@@ -318,9 +318,9 @@ def _compute_local_norm_of_gradient_3d(d,spacing, pnorm=2):
fdt = fd.FD_torch(spacing=spacing)
# need to use torch.norm here to make sure the proper subgradient is computed at zero

t0 = torch.norm(torch.stack((fdt.dXc(d),
fdt.dYc(d),
fdt.dZc(d))), pnorm, 0)
t0 = torch.norm(torch.stack((fdt.dXf(d),
fdt.dYf(d),
fdt.dZf(d))), pnorm, 0)

return t0
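Throughout this file the commit swaps the central-difference operators (dXc/dYc/dZc) for one-sided forward differences (dXf/dYf/dZf) inside the TV and gradient-norm terms. A minimal 1-D numpy sketch (the helper bodies below are illustrative, not the pyreg implementation) of why the stencil matters for an oscillating weight field:

import numpy as np

def dXf(d, dx=1.0):
    # forward difference, replicating the last value at the boundary (illustrative)
    return (np.concatenate((d[1:], d[-1:])) - d) / dx

def dXc(d, dx=1.0):
    # central difference with one-sided boundaries (illustrative)
    dp = np.concatenate((d[1:], d[-1:]))
    dm = np.concatenate((d[:1], d[:-1]))
    return (dp - dm) / (2.0 * dx)

osc = np.array([0.0, 1.0, 0.0, 1.0])      # rapidly oscillating weights
print(np.abs(dXf(osc)).sum())             # 3.0 -> the TV term penalizes the oscillation
print(np.abs(dXc(osc)).sum())             # 1.0 -> central differences largely miss it

Central differences nearly cancel on high-frequency patterns, so a TV penalty built on them barely discourages checkerboard-like weights; the forward stencil used after this change does.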

@@ -652,36 +652,43 @@ def weighted_softmax(input, dim=None, weights=None ):

ret = torch.zeros_like(input)

# for numerical reasons we first compute the maximum input along the dimension and then
# subtract it from all the exponents (this assures that we do not get exp(100) and then a NaN);
# this is ok, because we can multiply the numerator and denominator by the same constant
# and thereby shift the exponentials

max_in,_ = torch.max(input, dim=dim)

if dim==0:
norm = torch.zeros_like(input[0,...])
for c in range(sz[0]):
norm += weights[c]*torch.exp(input[c,...])
norm += weights[c]*torch.exp(input[c,...]-max_in)
for c in range(sz[0]):
ret[c,...] = weights[c]*torch.exp(input[c,...])/norm
ret[c,...] = weights[c]*torch.exp(input[c,...]-max_in)/norm
elif dim==1:
norm = torch.zeros_like(input[:,0, ...])
for c in range(sz[1]):
norm += weights[c] * torch.exp(input[:,c, ...])
norm += weights[c] * torch.exp(input[:,c, ...]-max_in)
for c in range(sz[1]):
ret[:,c, ...] = weights[c] * torch.exp(input[:,c, ...]) / norm
ret[:,c, ...] = weights[c] * torch.exp(input[:,c, ...]-max_in) / norm
elif dim==2:
norm = torch.zeros_like(input[:,:,0, ...])
for c in range(sz[2]):
norm += weights[c] * torch.exp(input[:,:,c, ...])
norm += weights[c] * torch.exp(input[:,:,c, ...]-max_in)
for c in range(sz[2]):
ret[:,:,c, ...] = weights[c] * torch.exp(input[:,:,c, ...]) / norm
ret[:,:,c, ...] = weights[c] * torch.exp(input[:,:,c, ...]-max_in) / norm
elif dim==3:
norm = torch.zeros_like(input[:,:,:,0, ...])
for c in range(sz[3]):
norm += weights[c] * torch.exp(input[:,:,:,c, ...])
norm += weights[c] * torch.exp(input[:,:,:,c, ...]-max_in)
for c in range(sz[3]):
ret[:,:,:,c, ...] = weights[c] * torch.exp(input[:,:,:,c, ...]) / norm
ret[:,:,:,c, ...] = weights[c] * torch.exp(input[:,:,:,c, ...]-max_in) / norm
elif dim==4:
norm = torch.zeros_like(input[:,:,:,:,0, ...])
for c in range(sz[4]):
norm += weights[c] * torch.exp(input[:,:,:,:,c, ...])
norm += weights[c] * torch.exp(input[:,:,:,:,c, ...]-max_in)
for c in range(sz[4]):
ret[:,:,:,:,c, ...] = weights[c] * torch.exp(input[:,:,:,:,c, ...]) / norm
ret[:,:,:,:,c, ...] = weights[c] * torch.exp(input[:,:,:,:,c, ...]-max_in) / norm
else:
raise ValueError('weighted_softmax is only supported for dimensions 0, 1, 2, 3, and 4.')
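The change above is the standard max-shift for a numerically stable (weighted) softmax: since w_c*exp(x_c - m) / sum_k w_k*exp(x_k - m) equals w_c*exp(x_c) / sum_k w_k*exp(x_k) for any constant m, subtracting the per-position maximum leaves the result unchanged while keeping every exponential bounded by 1. A minimal sketch (the values are illustrative):

import torch

w = torch.tensor([0.2, 0.3, 0.5])
x = torch.tensor([100.0, 101.0, 102.0])   # large pre-weights, e.g. raw network outputs

naive  = w * torch.exp(x) / (w * torch.exp(x)).sum()                      # exp(100) overflows float32 -> nan
stable = w * torch.exp(x - x.max()) / (w * torch.exp(x - x.max())).sum()  # finite, sums to 1

print(naive)    # tensor([nan, nan, nan])
print(stable)   # approximately tensor([0.0425, 0.1731, 0.7844])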

@@ -769,41 +776,48 @@ def weighted_sqrt_softmax(input, dim=None, weights=None ):

ret = torch.zeros_like(input)

# for numerical reasons we first compute the maximum input along the dimension and then
# subtract it from all the exponents (this assures that we do not get exp(100) and then a NaN);
# this is ok, because we can multiply the numerator and denominator by the same constant
# and thereby shift the exponentials

max_in, _ = torch.max(input, dim=dim)

if dim==0:
norm_sqr = torch.zeros_like(input[0,...])
for c in range(sz[0]):
norm_sqr += weights[c]*(torch.exp(input[c,...]))**2
norm_sqr += weights[c]*(torch.exp(input[c,...]-max_in))**2
norm = torch.sqrt(norm_sqr)
for c in range(sz[0]):
ret[c,...] = torch.sqrt(weights[c])*torch.exp(input[c,...])/norm
ret[c,...] = torch.sqrt(weights[c])*torch.exp(input[c,...]-max_in)/norm
elif dim==1:
norm_sqr = torch.zeros_like(input[:,0, ...])
for c in range(sz[1]):
norm_sqr += weights[c] * (torch.exp(input[:,c, ...]))**2
norm_sqr += weights[c] * (torch.exp(input[:,c, ...]-max_in))**2
norm = torch.sqrt(norm_sqr)
for c in range(sz[1]):
ret[:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,c, ...]) / norm
ret[:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,c, ...]-max_in) / norm
elif dim==2:
norm_sqr = torch.zeros_like(input[:,:,0, ...])
for c in range(sz[2]):
norm_sqr += weights[c] * (torch.exp(input[:,:,c, ...]))**2
norm_sqr += weights[c] * (torch.exp(input[:,:,c, ...]-max_in))**2
norm = torch.sqrt(norm_sqr)
for c in range(sz[2]):
ret[:,:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,:,c, ...]) / norm
ret[:,:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,:,c, ...]-max_in) / norm
elif dim==3:
norm_sqr = torch.zeros_like(input[:,:,:,0, ...])
for c in range(sz[3]):
norm_sqr += weights[c] * (torch.exp(input[:,:,:,c, ...]))**2
norm_sqr += weights[c] * (torch.exp(input[:,:,:,c, ...]-max_in))**2
norm = torch.sqrt(norm_sqr)
for c in range(sz[3]):
ret[:,:,:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,:,:,c, ...]) / norm
ret[:,:,:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,:,:,c, ...]-max_in) / norm
elif dim==4:
norm_sqr = torch.zeros_like(input[:,:,:,:,0, ...])
for c in range(sz[4]):
norm_sqr += weights[c] * (torch.exp(input[:,:,:,:,c, ...]))**2
norm_sqr += weights[c] * (torch.exp(input[:,:,:,:,c, ...]-max_in))**2
norm = torch.sqrt(norm_sqr)
for c in range(sz[4]):
ret[:,:,:,:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,:,:,:,c, ...]) / norm
ret[:,:,:,:,c, ...] = torch.sqrt(weights[c]) * torch.exp(input[:,:,:,:,c, ...]-max_in) / norm
else:
raise ValueError('weighted_sqrt_softmax is only supported for dimensions 0, 1, 2, 3, and 4.')
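For reference, with the same stabilizing shift m = max_c x_c the quantity computed here is (in my notation) ret_c = sqrt(w_c)*exp(x_c - m) / sqrt(sum_k w_k*exp(2*(x_k - m))), so the squared outputs sum to one over the chosen dimension; as in weighted_softmax, the factor exp(-m) cancels between numerator and denominator, which is why subtracting max_in does not change the result.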

@@ -1131,25 +1145,25 @@ def _compute_total_variation_1d(self, d):

# need to use torch.abs here to make sure the proper subgradient is computed at zero
batch_size = d.size()[0]
t0 = torch.abs(self.fdt.dXc(d))
t0 = torch.abs(self.fdt.dXf(d))

return (t0).sum()*self.volumeElement/batch_size

def _compute_total_variation_2d(self, d):

# need to use torch.norm here to make sure the proper subgradient is computed at zero
batch_size = d.size()[0]
t0 = torch.norm(torch.stack((self.fdt.dXc(d),self.fdt.dYc(d))),self.pnorm,0)
t0 = torch.norm(torch.stack((self.fdt.dXf(d),self.fdt.dYf(d))),self.pnorm,0)

return t0.sum()*self.volumeElement/batch_size

def _compute_total_variation_3d(self, d):

# need to use torch.norm here to make sure the proper subgradient is computed at zero
batch_size = d.size()[0]
t0 = torch.norm(torch.stack((self.fdt.dXc(d),
self.fdt.dYc(d),
self.fdt.dZc(d))), self.pnorm, 0)
t0 = torch.norm(torch.stack((self.fdt.dXf(d),
self.fdt.dYf(d),
self.fdt.dZf(d))), self.pnorm, 0)

return t0.sum()*self.volumeElement/batch_size

@@ -1190,6 +1204,24 @@ def get_current_penalty(self):
return self.current_penalty


def compute_local_weighted_tv_norm(self, I, weights):

sum_square_of_total_variation_penalty = Variable(MyTensor(self.nr_of_gaussians).zero_(), requires_grad=False)
# first compute the edge map
g_I = compute_localized_edge_penalty(I[:, 0, ...], self.spacing, self.params)
batch_size = I.size()[0]

# now compute the weighted TV norm channel-by-channel, square it, and then take the square root (as in color TV)
for g in range(self.nr_of_gaussians):
c_local_norm_grad = _compute_local_norm_of_gradient(weights[:, g, ...], self.spacing, self.pnorm)

to_sum = g_I * c_local_norm_grad * self.volumeElement / batch_size
current_tv = (to_sum).sum()
sum_square_of_total_variation_penalty[g] = current_tv**2

total_variation_penalty = torch.norm(sum_square_of_total_variation_penalty,p=2)
return total_variation_penalty
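As written, the new helper evaluates (in my notation) TV_g = sum_x g_I(x) * ||grad w_g(x)||_p * volumeElement / batch_size for each Gaussian channel g, stores TV_g**2 in a vector, and returns that vector's Euclidean norm, i.e. sqrt(sum_g TV_g**4); this couples the channels in the spirit of a color/vectorial TV penalty instead of simply summing the per-channel terms as the previous inline implementation did. Both forward() methods below now call this helper.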

class encoder_block_2d(nn.Module):
def __init__(self, input_feature, output_feature, use_dropout, use_batch_normalization):
super(encoder_block_2d, self).__init__()
@@ -1480,13 +1512,7 @@ def forward(self, I, additional_inputs, global_multi_gaussian_weights, gaussian_
# compute the total variation penalty
total_variation_penalty = Variable(MyTensor(1).zero_(), requires_grad=False)
if self.total_variation_weight_penalty > 0:
# first compute the edge map
g_I = compute_localized_edge_penalty(I[:, 0, ...], self.spacing, self.params)
batch_size = I.size()[0]
for g in range(self.nr_of_gaussians):
# total_variation_penalty += self.compute_total_variation(weights[:,g,...])
c_local_norm_grad = _compute_local_norm_of_gradient(weights[:, g, ...], self.spacing, self.pnorm)
total_variation_penalty += (utils.remove_infs_from_variable(g_I * c_local_norm_grad)).sum() * self.volumeElement / batch_size
total_variation_penalty += self.compute_local_weighted_tv_norm(I, weights)

diffusion_penalty = Variable(MyTensor(1).zero_(), requires_grad=False)
if self.diffusion_weight_penalty > 0:
@@ -1832,13 +1858,7 @@ def forward(self, I, additional_inputs, global_multi_gaussian_weights, gaussian_
# compute the total variation penalty; compute this on the pre (non-smoothed) weights
total_variation_penalty = Variable(MyTensor(1).zero_(), requires_grad=False)
if self.total_variation_weight_penalty > 0:
# first compute the edge map
g_I = compute_localized_edge_penalty(I[:, 0, ...], self.spacing, self.params)
batch_size = I.size()[0]
for g in range(self.nr_of_gaussians):
# total_variation_penalty += self.compute_total_variation(weights[:,g,...])
c_local_norm_grad = _compute_local_norm_of_gradient(weights[:, g, ...], self.spacing, self.pnorm)
total_variation_penalty += (utils.remove_infs_from_variable(g_I * c_local_norm_grad)).sum() * self.volumeElement / batch_size
total_variation_penalty += self.compute_local_weighted_tv_norm(I=I,weights=weights)

diffusion_penalty = Variable(MyTensor(1).zero_(), requires_grad=False)
if self.diffusion_weight_penalty > 0:
44 changes: 43 additions & 1 deletion pyreg/finite_differences.py
@@ -179,7 +179,7 @@ def lap(self, I):
else:
raise ValueError('Finite differences are only supported in dimensions 1 to 3')

def grad_norm_sqr(self, I):
def grad_norm_sqr_c(self, I):
"""
Computes the gradient norm of an image
!!!!!!!!!!!
@@ -200,6 +200,48 @@ def grad_norm_sqr(self, I):
else:
raise ValueError('Finite differences are only supported in dimensions 1 to 3')

def grad_norm_sqr_f(self, I):
"""
Computes the squared gradient norm of an image using forward differences
!!!!!!!!!!!
IMPORTANT:
ALL OF THE FOLLOWING CODE ADDS 1 TO THE DIMENSION COUNT, WHICH REPRESENTS THE BATCH DIMENSION.
THIS IS FOR COMPUTATIONAL EFFICIENCY.
:param I: Input image [batch, X,Y,Z]
:return: returns ||grad I||^2
"""
ndim = self.getdimension(I)
if ndim == 1 + 1:
return self.dXf(I)**2
elif ndim == 2 + 1:
return (self.dXf(I)**2 + self.dYf(I)**2)
elif ndim == 3 + 1:
return (self.dXf(I)**2 + self.dYf(I)**2 + self.dZf(I)**2)
else:
raise ValueError('Finite differences are only supported in dimensions 1 to 3')

def grad_norm_sqr_b(self, I):
"""
Computes the squared gradient norm of an image using backward differences
!!!!!!!!!!!
IMPORTANT:
ALL OF THE FOLLOWING CODE ADDS 1 TO THE DIMENSION COUNT, WHICH REPRESENTS THE BATCH DIMENSION.
THIS IS FOR COMPUTATIONAL EFFICIENCY.
:param I: Input image [batch, X,Y,Z]
:return: returns ||grad I||^2
"""
ndim = self.getdimension(I)
if ndim == 1 + 1:
return self.dXb(I)**2
elif ndim == 2 + 1:
return (self.dXb(I)**2 + self.dYb(I)**2)
elif ndim == 3 + 1:
return (self.dXb(I)**2 + self.dYb(I)**2 + self.dZb(I)**2)
else:
raise ValueError('Finite differences are only supported in dimensions 1 to 3')

@abstractmethod
def getdimension(self,I):
"""
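A minimal usage sketch of the renamed and newly added variants, assuming a 2-D image batch of shape [batch, X, Y], that FD_torch exposes these methods directly, and that it is constructed with a numpy spacing array and accepts a plain tensor as in the calls elsewhere in this commit (shapes and values are illustrative):

import numpy as np
import torch
import pyreg.finite_differences as fd

spacing = np.array([0.01, 0.01])
I = torch.rand(1, 64, 64)                 # [batch, X, Y]

fdt = fd.FD_torch(spacing=spacing)
gn_c = fdt.grad_norm_sqr_c(I)             # central differences (previously grad_norm_sqr)
gn_f = fdt.grad_norm_sqr_f(I)             # forward differences, now used by the edge penalty
gn_b = fdt.grad_norm_sqr_b(I)             # backward differences
print(gn_f.shape)                         # expected: torch.Size([1, 64, 64])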
2 changes: 1 addition & 1 deletion pyreg/smoother_factory.py
@@ -163,7 +163,7 @@ def _do_CFL_clamping_if_necessary(self,v, clampCFL_dt):
:return: clamped velocity field
"""

rk4_factor = 2*np.sqrt(2)/self.dim*0.9 # 0.9 is safety margin (see the paper by Polzin et al. for this RK4 stability condition)
rk4_factor = 2*np.sqrt(2)/self.dim*0.75 # 0.75 is safety margin (see the paper by Polzin et al. for this RK4 stability condition)

if clampCFL_dt is not None:

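If this factor is the admissible CFL number |v|*dt/dx for the RK4 integrator (as the cited Polzin et al. stability condition suggests), the change tightens the 2-D bound from 2*sqrt(2)/2*0.9 ≈ 1.27 to 2*sqrt(2)/2*0.75 ≈ 1.06, and the 3-D bound from ≈ 0.85 to ≈ 0.71, i.e. velocities are clamped roughly 17% earlier.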
3 changes: 3 additions & 0 deletions pyreg/utils.py
@@ -33,6 +33,9 @@
print('WARNING: nn_interpolation could not be imported (only supported in CUDA at the moment), some functionality may not be available.')


def my_hasnan(x):
return (x != x).any()

def create_symlink_with_correct_ext(sf,tf):
abs_s = os.path.abspath(sf)
ext_s = os.path.splitext(abs_s)[1]
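The new my_hasnan helper relies on NaN being the only floating-point value that is not equal to itself, so it works for torch tensors and should work equally for numpy arrays. A minimal usage sketch (values are illustrative):

import torch
from pyreg.utils import my_hasnan

x = torch.tensor([1.0, float('nan'), 3.0])
print(bool(my_hasnan(x)))      # True
print(bool(my_hasnan(x[:1])))  # False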
