Commit cf5a3eb

Merge pull request #77 from mkolod/python36
Make FlowNet2 work with PyTorch 0.4.1
2 parents: 3d6db9f + 12f794c


47 files changed: +1204 additions, -1166 deletions

install.sh

Lines changed: 3 additions & 3 deletions
@@ -1,8 +1,8 @@
 #!/bin/bash
 cd ./networks/correlation_package
-./make.sh
+python setup.py install
 cd ../resample2d_package
-./make.sh
+python setup.py install
 cd ../channelnorm_package
-./make.sh
+python setup.py install
 cd ..
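The install script now builds each package with setuptools instead of the old make.sh/cffi flow. The per-package setup.py files themselves are not shown in this extract; as a rough sketch only (module and file names are assumptions, not taken from this diff), a setup.py for a PyTorch 0.4.1 CUDA extension typically looks like:

# Hypothetical sketch of a package-level setup.py; file and module names are
# assumptions for illustration, the real files are not part of this diff.
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension

setup(
    name='channelnorm_cuda',
    ext_modules=[
        CUDAExtension(
            name='channelnorm_cuda',
            sources=['channelnorm_cuda.cc', 'channelnorm_kernel.cu'],
        ),
    ],
    cmdclass={'build_ext': BuildExtension},
)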

models.py

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 import math
 import numpy as np

-from networks.resample2d_package.modules.resample2d import Resample2d
-from networks.channelnorm_package.modules.channelnorm import ChannelNorm
+from networks.resample2d_package.resample2d import Resample2d
+from networks.channelnorm_package.channelnorm import ChannelNorm

 from networks import FlowNetC
 from networks import FlowNetS

networks/FlowNetC.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import math
 import numpy as np

-from .correlation_package.modules.correlation import Correlation
+from .correlation_package.correlation import Correlation

 from .submodules import *
 'Parameter count , 39,175,298 '

networks/channelnorm_package/build.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

networks/channelnorm_package/functions/channelnorm.py renamed to networks/channelnorm_package/channelnorm.py

Lines changed: 15 additions & 4 deletions
@@ -1,6 +1,6 @@
 from torch.autograd import Function, Variable
-from .._ext import channelnorm
-
+from torch.nn.modules.module import Module
+import channelnorm_cuda

 class ChannelNormFunction(Function):

@@ -10,7 +10,7 @@ def forward(ctx, input1, norm_deg=2):
         b, _, h, w = input1.size()
         output = input1.new(b, 1, h, w).zero_()

-        channelnorm.ChannelNorm_cuda_forward(input1, output, norm_deg)
+        channelnorm_cuda.forward(input1, output, norm_deg)
         ctx.save_for_backward(input1, output)
         ctx.norm_deg = norm_deg

@@ -22,7 +22,18 @@ def backward(ctx, grad_output):

         grad_input1 = Variable(input1.new(input1.size()).zero_())

-        channelnorm.ChannelNorm_cuda_backward(input1, output, grad_output.data,
+        channelnorm_cuda.backward(input1, output, grad_output.data,
                                               grad_input1.data, ctx.norm_deg)

         return grad_input1, None
+
+
+class ChannelNorm(Module):
+
+    def __init__(self, norm_deg=2):
+        super(ChannelNorm, self).__init__()
+        self.norm_deg = norm_deg
+
+    def forward(self, input1):
+        return ChannelNormFunction.apply(input1, self.norm_deg)
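For reference, the new ChannelNorm module simply wraps ChannelNormFunction so it can be dropped into a model like any other nn.Module. A minimal usage sketch (tensor shapes are illustrative, and a CUDA device is required since only a CUDA kernel is bound):

# Illustrative usage of the reworked ChannelNorm module; shapes chosen arbitrarily.
import torch
from networks.channelnorm_package.channelnorm import ChannelNorm

channelnorm = ChannelNorm(norm_deg=2)
flow = torch.randn(4, 2, 64, 64, device='cuda')  # e.g. a 2-channel flow field
norm = channelnorm(flow)                          # per-pixel L2 norm, shape (4, 1, 64, 64)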
New file (file name not shown in this extract): pybind11 bindings exposing channelnorm_cuda.forward / channelnorm_cuda.backward

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#include <torch/torch.h>
+#include <ATen/ATen.h>
+
+#include "channelnorm_kernel.cuh"
+
+int channelnorm_cuda_forward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    int norm_deg) {
+
+    channelnorm_kernel_forward(input1, output, norm_deg);
+    return 1;
+}
+
+
+int channelnorm_cuda_backward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg) {
+
+    channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
+    return 1;
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
+  m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
+}
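The PYBIND11_MODULE block above is what makes `import channelnorm_cuda` in channelnorm.py resolve once the extension has been built by install.sh. For quick experiments the same sources can also be JIT-compiled with torch.utils.cpp_extension.load; a sketch, with source paths that are guesses rather than names confirmed by this diff:

# Hypothetical JIT-build sketch; the source paths are assumptions.
from torch.utils.cpp_extension import load

channelnorm_cuda = load(
    name='channelnorm_cuda',
    sources=[
        'networks/channelnorm_package/channelnorm_cuda.cc',
        'networks/channelnorm_package/channelnorm_kernel.cu',
    ],
)
# The bound names then match the m.def(...) calls above:
#   channelnorm_cuda.forward(input1, output, norm_deg)
#   channelnorm_cuda.backward(input1, output, grad_output, grad_input1, norm_deg)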
New file (file name not shown in this extract): CUDA kernels implementing channelnorm_kernel_forward and channelnorm_kernel_backward

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
+#include <ATen/ATen.h>
+#include <ATen/Context.h>
+
+#include "channelnorm_kernel.cuh"
+
+#define CUDA_NUM_THREADS 512
+
+#define DIM0(TENSOR) ((TENSOR).x)
+#define DIM1(TENSOR) ((TENSOR).y)
+#define DIM2(TENSOR) ((TENSOR).z)
+#define DIM3(TENSOR) ((TENSOR).w)
+
+#define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
+
+using at::Half;
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_update_output(
+    const int n,
+    const scalar_t* __restrict__ input1,
+    const long4 input1_size,
+    const long4 input1_stride,
+    scalar_t* __restrict__ output,
+    const long4 output_size,
+    const long4 output_stride,
+    int norm_deg) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    int dim_b = DIM0(output_size);
+    int dim_c = DIM1(output_size);
+    int dim_h = DIM2(output_size);
+    int dim_w = DIM3(output_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index )           % dim_w;
+
+    int i1dim_c = DIM1(input1_size);
+    int i1dim_h = DIM2(input1_size);
+    int i1dim_w = DIM3(input1_size);
+    int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
+    int i1dim_hw  = i1dim_h * i1dim_w;
+
+    float result = 0.0;
+
+    for (int c = 0; c < i1dim_c; ++c) {
+        int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
+        scalar_t val = input1[i1Index];
+        result += static_cast<float>(val * val);
+    }
+    result = sqrt(result);
+    output[index] = static_cast<scalar_t>(result);
+}
+
+
+template <typename scalar_t>
+__global__ void kernel_channelnorm_backward_input1(
+    const int n,
+    const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
+    const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride,
+    const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
+    scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride,
+    int norm_deg) {
+
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (index >= n) {
+        return;
+    }
+
+    float val = 0.0;
+
+    int dim_b = DIM0(gradInput_size);
+    int dim_c = DIM1(gradInput_size);
+    int dim_h = DIM2(gradInput_size);
+    int dim_w = DIM3(gradInput_size);
+    int dim_chw = dim_c * dim_h * dim_w;
+    int dim_hw  = dim_h * dim_w;
+
+    int b = ( index / dim_chw ) % dim_b;
+    int y = ( index / dim_w )   % dim_h;
+    int x = ( index )           % dim_w;
+
+
+    int outIndex = b * dim_hw + y * dim_w + x;
+    val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex]) + 1e-9);
+    gradInput[index] = static_cast<scalar_t>(val);
+
+}
+
+void channelnorm_kernel_forward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    int norm_deg) {
+
+    const long4 input1_size   = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 output_size   = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    int n = output.numel();
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {
+
+        kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
+            n,
+            input1.data<scalar_t>(),
+            input1_size,
+            input1_stride,
+            output.data<scalar_t>(),
+            output_size,
+            output_stride,
+            norm_deg);
+
+    }));
+
+    // TODO: ATen-equivalent check
+
+    // THCudaCheck(cudaGetLastError());
+}
+
+void channelnorm_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg) {
+
+    const long4 input1_size   = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
+    const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
+
+    const long4 output_size   = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
+    const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
+
+    const long4 gradOutput_size   = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
+    const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
+
+    const long4 gradInput1_size   = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
+    const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
+
+    int n = gradInput1.numel();
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {
+
+        kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
+            n,
+            input1.data<scalar_t>(),
+            input1_size,
+            input1_stride,
+            output.data<scalar_t>(),
+            output_size,
+            output_stride,
+            gradOutput.data<scalar_t>(),
+            gradOutput_size,
+            gradOutput_stride,
+            gradInput1.data<scalar_t>(),
+            gradInput1_size,
+            gradInput1_stride,
+            norm_deg
+        );
+
+    }));
+
+    // TODO: Add ATen-equivalent check
+
+    // THCudaCheck(cudaGetLastError());
+}
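Numerically, the forward kernel is a per-pixel L2 norm over channels (norm_deg is accepted but the implementation squares and square-roots regardless), and the backward kernel computes gradOutput * input1 / (output + 1e-9). A pure-PyTorch reference of the same math, handy for sanity-checking the extension, might look like:

# Pure-PyTorch reference for the CUDA kernels above (norm_deg=2 behaviour only);
# intended only as a correctness check, not a replacement for the extension.
import torch

def channelnorm_forward_ref(input1):
    # (B, C, H, W) -> (B, 1, H, W): sqrt of the channel-wise sum of squares
    return input1.pow(2).sum(dim=1, keepdim=True).sqrt()

def channelnorm_backward_ref(input1, output, grad_output):
    # matches: gradInput = gradOutput * input1 / (output + 1e-9), broadcast over channels
    return grad_output * input1 / (output + 1e-9)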
New file (name not shown in this extract; the #include directives above point to channelnorm_kernel.cuh): declarations for the CUDA kernel launchers

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+void channelnorm_kernel_forward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    int norm_deg);
+
+
+void channelnorm_kernel_backward(
+    at::Tensor& input1,
+    at::Tensor& output,
+    at::Tensor& gradOutput,
+    at::Tensor& gradInput1,
+    int norm_deg);

networks/channelnorm_package/functions/__init__.py

Whitespace-only changes.

networks/channelnorm_package/make.sh

Lines changed: 0 additions & 12 deletions
This file was deleted.
