#
# Author : Alwyn Mathew
#
# Monodepth in pytorch (https://github.com/alwynmathew/monodepth-pytorch)
# Bilinear sampler in pytorch (https://github.com/alwynmathew/bilinear-sampler-pytorch)
#
from __future__ import absolute_import, division, print_function

import torch
from torch.nn.functional import pad

from main_options import MainOptions
opt = MainOptions().parse()


def apply_disparity(input_images, x_offset, wrap_mode='border', tensor_type='torch.cuda.FloatTensor'):
    num_batch, num_channels, height, width = input_images.size()

    # Handle both texture border types
    edge_size = 0
    if wrap_mode == 'border':
        edge_size = 1
        # Pad last and second-to-last dimensions by 1 from both sides
        input_images = pad(input_images, (1, 1, 1, 1))
    elif wrap_mode == 'edge':
        edge_size = 0
    else:
        return None

    # Put channels to slowest dimension and flatten batch with respect to others
    input_images = input_images.permute(1, 0, 2, 3).contiguous()
    im_flat = input_images.view(num_channels, -1)

    # Create meshgrid of pixel indices (written before torch.meshgrid was
    # widely available; a modern alternative is sketched in the note below)
    x = torch.linspace(0, width - 1, width).repeat(height, 1).type(tensor_type).to(opt.gpu_ids)
    y = torch.linspace(0, height - 1, height).repeat(width, 1).transpose(0, 1).type(tensor_type).to(opt.gpu_ids)
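    # NOTE (editorial, hedged): on recent PyTorch versions the two lines above
    # can be written with the built-in meshgrid, whose default 'ij' indexing
    # matches the layout built here, e.g.:
    #   yy, xx = torch.meshgrid(torch.arange(height), torch.arange(width))
    #   x = xx.type(tensor_type).to(opt.gpu_ids)
    #   y = yy.type(tensor_type).to(opt.gpu_ids)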
    # Take padding into account
    x = x + edge_size
    y = y + edge_size
    # Flatten and repeat for each image in the batch
    x = x.view(-1).repeat(1, num_batch)
    y = y.view(-1).repeat(1, num_batch)

    # Now we want to sample pixels with indices shifted by disparity in X direction
    # For that we convert disparity from a fraction of image width to pixels
    # and add it to the X indices
    x = x + x_offset.contiguous().view(-1) * width
    # Make sure we don't go outside of the image
    x = torch.clamp(x, 0.0, width - 1 + 2 * edge_size)
    # Y offsets stay integer-valued, so flooring is exact
    # (no interpolation needed in Y)
    y0 = torch.floor(y)
    # In X direction round both down and up to apply linear interpolation
    # between them later
    x0 = torch.floor(x)
    x1 = x0 + 1
    # After rounding up we might go outside the image boundaries again
    x1 = x1.clamp(max=(width - 1 + 2 * edge_size))

    # Calculate indices to draw from flattened version of image batch
    dim2 = width + 2 * edge_size
    dim1 = (width + 2 * edge_size) * (height + 2 * edge_size)
    # Set offsets for each image in the batch
    base = dim1 * torch.arange(num_batch).type(tensor_type).to(opt.gpu_ids)
    base = base.view(-1, 1).repeat(1, height * width).view(-1)
    # One pixel shift in Y direction equals a dim2 shift in the flattened array
    base_y0 = base + y0 * dim2
    # Add the two versions of the shift in X direction separately
    idx_l = base_y0 + x0
    idx_r = base_y0 + x1
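    # Worked example of the flat indexing above (editorial note): for batch b
    # and integer sample location (y0, x0) in the padded image, the flat
    # offset is b * dim1 + y0 * dim2 + x0. With width=4, height=3 and
    # edge_size=1: dim2 = 6, dim1 = 30, so (b=1, y0=2, x0=3) maps to
    # 1 * 30 + 2 * 6 + 3 = 45 in im_flat.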
    # Sample pixels from images
    pix_l = im_flat.gather(1, idx_l.repeat(num_channels, 1).long())
    pix_r = im_flat.gather(1, idx_r.repeat(num_channels, 1).long())

    # Apply linear interpolation to account for fractional offsets; only the
    # X direction needs it, since the Y indices remain integers
    weight_l = x1 - x
    weight_r = x - x0
    output = weight_l * pix_l + weight_r * pix_r

    # Reshape back into image batch and permute back to (N, C, H, W) shape
    output = output.view(num_channels, num_batch, height, width).permute(1, 0, 2, 3)

    return output
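

# Minimal usage sketch (editorial, hedged): assumes a CUDA device is available
# and that MainOptions().parse() above succeeded with default arguments. Per
# the function's contract, images are (N, C, H, W) and x_offset holds
# per-pixel horizontal disparities expressed as a fraction of the image width.
if __name__ == '__main__':
    left = torch.rand(2, 3, 64, 128).type('torch.cuda.FloatTensor')   # (N, C, H, W) image batch
    disp = torch.zeros(2, 1, 64, 128).type('torch.cuda.FloatTensor')  # zero disparity -> identity warp
    warped = apply_disparity(left, disp)
    print(warped.size())  # expected: torch.Size([2, 3, 64, 128])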