Skip to content

Commit

Permalink
Merge pull request #125 from uncbiag/new_kernel
Browse files Browse the repository at this point in the history
New kernel
  • Loading branch information
marcniethammer authored Jul 7, 2018
2 parents 47da87f + e5ecdf2 commit c6d4662
Show file tree
Hide file tree
Showing 13 changed files with 1,476 additions and 530 deletions.
103 changes: 81 additions & 22 deletions demos/create_synthetic_regularization_test_cases.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import print_function
from future.utils import native_str

from builtins import str
from builtins import range

import set_pyreg_paths

import matplotlib as matplt
from pyreg.config_parser import MATPLOTLIB_AGG
if MATPLOTLIB_AGG:
matplt.use('Agg')
#from pyreg.config_parser import MATPLOTLIB_AGG
#if MATPLOTLIB_AGG:
# matplt.use('Agg')

import matplotlib.pyplot as plt
import scipy.ndimage as ndimage
Expand All @@ -16,6 +18,7 @@
import pyreg.finite_differences as fd
import pyreg.custom_pytorch_extensions as ce
import pyreg.smoother_factory as sf
import pyreg.deep_smoothers as ds

from pyreg.data_wrapper import AdaptVal

Expand Down Expand Up @@ -326,24 +329,63 @@ def _compute_ring_radii(extent, nr_of_rings, randomize_radii, randomize_factor=0

return rings_at

def compute_localized_velocity_from_momentum(m,weights,multi_gaussian_stds,sz,spacing,visualize=False):
def compute_localized_velocity_from_momentum(m,weights,multi_gaussian_stds,sz,spacing,kernel_weighting_type='w_K',visualize=False):

nr_of_gaussians = len(multi_gaussian_stds)
# create a velocity field from this momentum using a multi-Gaussian kernel
gaussian_fourier_filter_generator = ce.GaussianFourierFilterGenerator(sz[2:], spacing, nr_of_gaussians)

vs = ce.fourier_set_of_gaussian_convolutions(AdaptVal(Variable(torch.from_numpy(m), requires_grad=False)),
gaussian_fourier_filter_generator,
sigma=AdaptVal(Variable(torch.from_numpy(multi_gaussian_stds),
requires_grad=False)),
compute_std_gradients=False)

gaussian_fourier_filter_generator = ce.GaussianFourierFilterGenerator(sz[2:], spacing, nr_of_slots=nr_of_gaussians)

t_weights = Variable(torch.from_numpy(weights),requires_grad=False)
t_momentum = Variable(torch.from_numpy(m),requires_grad=False)

if kernel_weighting_type=='sqrt_w_K_sqrt_w':
sqrt_weights = torch.sqrt(t_weights)
sqrt_weighted_multi_smooth_v = ds.compute_weighted_multi_smooth_v(momentum=t_momentum, weights=sqrt_weights,
gaussian_stds=multi_gaussian_stds,
gaussian_fourier_filter_generator=gaussian_fourier_filter_generator)
elif kernel_weighting_type=='w_K_w':
# now create the weighted multi-smooth-v
weighted_multi_smooth_v = ds.compute_weighted_multi_smooth_v(momentum=t_momentum, weights=t_weights,
gaussian_stds=multi_gaussian_stds,
gaussian_fourier_filter_generator=gaussian_fourier_filter_generator)
elif kernel_weighting_type=='w_K':
multi_smooth_v = ce.fourier_set_of_gaussian_convolutions(t_momentum,
gaussian_fourier_filter_generator=gaussian_fourier_filter_generator,
sigma=Variable(torch.from_numpy(multi_gaussian_stds),requires_grad=False),
compute_std_gradients=False)

# now compute the localized_velocity
# compute velocity based on localized weights
localized_v = np.zeros([1, 2] + sz[2:], dtype='float32')
dims = localized_v.shape[1]
for g in range(nr_of_gaussians):
for d in range(dims):
localized_v[0, d, ...] += weights[0, g, ...] * vs[g, 0, d, ...].data.cpu().numpy()

# now we apply this weight across all the channels; weight output is B x weights x X x Y
for n in range(dims):
# reverse the order so that for a given channel we have batch x multi_velocity x X x Y
# i.e., the multi-velocity field output is treated as a channel
# reminder: # format of multi_smooth_v is multi_v x batch x channels x X x Y
# (channels here are the vector field components); i.e. as many as there are dimensions
# each one of those should be smoothed the same

# let's smooth this on the fly, as the smoothing will be of form
# w_i*K_i*(w_i m)

if kernel_weighting_type=='sqrt_w_K_sqrt_w':
# roc should be: batch x multi_v x X x Y
roc = torch.transpose(sqrt_weighted_multi_smooth_v[:, :, n, ...], 0, 1)
yc = torch.sum(roc * sqrt_weights, dim=1)
elif kernel_weighting_type=='w_K_w':
# roc should be: batch x multi_v x X x Y
roc = torch.transpose(weighted_multi_smooth_v[:, :, n, ...], 0, 1)
yc = torch.sum(roc * t_weights, dim=1)
elif kernel_weighting_type=='w_K':
# roc should be: batch x multi_v x X x Y
roc = torch.transpose(multi_smooth_v[:, :, n, ...], 0, 1)
yc = torch.sum(roc * t_weights, dim=1)
else:
raise ValueError('Unknown kernel type: {}'.format(kernel_weighting_type))

localized_v[:, n, ...] = yc.data.cpu().numpy() # ret is: batch x channels x X x Y

if visualize:

Expand Down Expand Up @@ -382,21 +424,22 @@ def compute_map_from_v(localized_v,sz,spacing):

def add_texture(im_orig):
sz = im_orig.shape
rand_noise = np.random.random(sz[2:])
rand_noise = np.random.random(sz[2:])-0.5
rand_noise = rand_noise.view().reshape(sz)
r_params = pars.ParameterDict()
r_params['smoother']['type'] = 'gaussian'
r_params['smoother']['gaussian_std'] = 0.015
r_params['smoother']['gaussian_std'] = 0.02
s_r = sf.SmootherFactory(sz[2::], spacing).create_smoother(r_params)

rand_noise_smoothed = s_r.smooth(AdaptVal(Variable(torch.from_numpy(rand_noise), requires_grad=False))).data.cpu().numpy()
rand_noise_smoothed /= rand_noise_smoothed.max()
rand_noise_smoothed /= 2*rand_noise_smoothed.max()

im = im_orig + rand_noise_smoothed

return im

def create_random_image_pair(weights_not_fluid,weights_fluid,weights_neutral,weight_smoothing_std,multi_gaussian_stds,
kernel_weighting_type,
randomize_momentum_on_circle,randomize_in_sectors,
put_weights_between_circles,
start_with_fluid_weight,
Expand Down Expand Up @@ -451,6 +494,8 @@ def create_random_image_pair(weights_not_fluid,weights_fluid,weights_neutral,wei
#weights_old = np.zeros_like(weights_orig)
#weights_old[:] = weights_orig
weights_orig = (smoother.smooth(Variable(torch.from_numpy(weights_orig),requires_grad=False))).data.cpu().numpy()
# make sure they are strictly positive
weights_orig[weights_orig<0] = 0

if publication_figures_directory is not None:
plt.clf()
Expand All @@ -477,7 +522,7 @@ def create_random_image_pair(weights_not_fluid,weights_fluid,weights_neutral,wei
publication_prefix='circle_init',
image_pair_nr=image_pair_nr)

localized_v_orig = compute_localized_velocity_from_momentum(m=m_orig,weights=weights_orig,multi_gaussian_stds=multi_gaussian_stds,sz=sz,spacing=spacing)
localized_v_orig = compute_localized_velocity_from_momentum(m=m_orig,weights=weights_orig,multi_gaussian_stds=multi_gaussian_stds,sz=sz,spacing=spacing,kernel_weighting_type=kernel_weighting_type)

if publication_figures_directory is not None:
plt.clf()
Expand All @@ -497,6 +542,15 @@ def create_random_image_pair(weights_not_fluid,weights_fluid,weights_neutral,wei
plt.axis('off')
plt.savefig(os.path.join(publication_figures_directory, 'ring_im_orig_textured_{:d}.pdf'.format(image_pair_nr)),bbox_inches='tight',pad_inches=0)

# plt.clf()
# plt.subplot(1,2,1)
# plt.imshow(ring_im[0,0,...],clim=(-0.5,2.5))
# plt.colorbar()
# plt.subplot(1,2,2)
# plt.imshow(ring_im_orig[0, 0, ...], clim=(-0.5, 2.5))
# plt.colorbar()
# plt.show()

else:
ring_im = ring_im_orig

Expand All @@ -520,6 +574,8 @@ def create_random_image_pair(weights_not_fluid,weights_fluid,weights_neutral,wei
id_c_warped = id_c_warped_t.data.cpu().numpy()
weights_warped_t = utils.compute_warped_image_multiNC(AdaptVal(Variable(torch.from_numpy(weights_orig),requires_grad=False)), phi1_orig, spacing, spline_order=1)
weights_warped = weights_warped_t.data.cpu().numpy()
# make sure they are stirctly positive
weights_warped[weights_warped<0] = 0

warped_source_im_orig = I1_label_orig.data.cpu().numpy()

Expand All @@ -536,7 +592,7 @@ def create_random_image_pair(weights_not_fluid,weights_fluid,weights_neutral,wei

localized_v_warped = compute_localized_velocity_from_momentum(m=m_warped_source, weights=weights_warped,
multi_gaussian_stds=multi_gaussian_stds, sz=sz,
spacing=spacing)
spacing=spacing,kernel_weighting_type=kernel_weighting_type)

if publication_figures_directory is not None:
plt.clf()
Expand Down Expand Up @@ -714,6 +770,8 @@ def get_parameter_value_flag(command_line_par,params, params_name, default_val,
parser.add_argument('--weights_fluid', required=False,type=str, default=None, help='weights for a fluid circle; default=[0.2,0.5,0.2,0.1]')
parser.add_argument('--weights_background', required=False,type=str, default=None, help='weights for the background; default=[0,0,0,1]')

parser.add_argument('--kernel_weighting_type', required=False, type=str, default='w_K', help='Which kernel weighting to use for integration. Specify as [w_K|w_K_w|sqrt_w_K_sqrt_w]; w_K is the default')

parser.add_argument('--nr_of_angles', required=False, default=None, type=int, help='number of angles for randomize in sector') #10
parser.add_argument('--multiplier_factor', required=False, default=None, type=float, help='value the random momentum is multiplied by') #1.0
parser.add_argument('--momentum_smoothing', required=False, default=None, type=int, help='how much the randomly generated momentum is smoothed') #0.05
Expand All @@ -738,7 +796,7 @@ def get_parameter_value_flag(command_line_par,params, params_name, default_val,
nr_of_pairs_to_generate = args.nr_of_pairs_to_generate

nr_of_circles_to_generate = get_parameter_value(args.nr_of_circles_to_generate, params,'nr_of_circles_to_generate', 2, 'number of circles for the synthetic data')
circle_extent = get_parameter_value(args.circle_extent, params, 'circle_extent', 0.02, 'Size of largest circle; image is [-0.5,0.5]^2')
circle_extent = get_parameter_value(args.circle_extent, params, 'circle_extent', 0.2, 'Size of largest circle; image is [-0.5,0.5]^2')

randomize_momentum_on_circle = get_parameter_value_flag(not args.do_not_randomize_momentum,params=params, params_name='randomize_momentum_on_circle',
default_val=True, params_description='randomizes the momentum on the circles')
Expand Down Expand Up @@ -818,7 +876,7 @@ def get_parameter_value_flag(command_line_par,params, params_name, default_val,
sz = [1, 1, sz[0], sz[1]]
spacing = 1.0 / (np.array(sz[2:]) - 1)

output_dir = args.output_directory
output_dir = os.path.normpath(args.output_directory)+'_kernel_weighting_type_' + native_str(args.kernel_weighting_type)

image_output_dir = os.path.join(output_dir,'brain_affine_icbm')
label_output_dir = os.path.join(output_dir,'label_affine_icbm')
Expand Down Expand Up @@ -880,6 +938,7 @@ def get_parameter_value_flag(command_line_par,params, params_name, default_val,
weights_neutral=weights_neutral,
weight_smoothing_std=args.weight_smoothing_std,
multi_gaussian_stds=multi_gaussian_stds,
kernel_weighting_type=args.kernel_weighting_type,
randomize_momentum_on_circle=randomize_momentum_on_circle,
randomize_in_sectors=randomize_in_sectors,
put_weights_between_circles=put_weights_between_circles,
Expand Down
177 changes: 0 additions & 177 deletions demos/histogram_computations.py

This file was deleted.

Loading

0 comments on commit c6d4662

Please sign in to comment.