Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tensor size >4G support for bwd/wrw #80

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion driver/conv_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ int main(int argc, char **argv) {
// gen rand
gen_rand_vector<float, float>(host_output, static_cast<size_t>(n) * k * ho * wo, 0.0, 1.0);
gen_rand_vector<float, float>(host_weight, static_cast<size_t>(k) * c * y * x, -0.5, 0.5);
gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 999999., 9999999.); // manually input value to a very large number
//gen_rand_vector<float, float>(host_input, static_cast<size_t>(n) * c * hi * wi, 999999., 9999999.); // manually input value to a very large number
memset(host_input, 0xfc, static_cast<size_t>(k) * c * y * x * sizeof(float));
// gen_rand_vector<float, int>(host_output, static_cast<size_t>(n) * k * ho * wo,1, 1);
// gen_rand_vector<float, int>(host_weight, static_cast<size_t>(k) * c * y * x, 1, 1);
#ifdef USE_GPU_NAIVE_CONV
Expand Down
32 changes: 26 additions & 6 deletions driver/igemm_bwd_gtc_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,13 @@ class igemm_bwd_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
Expand Down Expand Up @@ -340,9 +347,17 @@ class igemm_bwd_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
if(splits == 0){
printf("image size (c*h*w or k*h*w) is bigger than 4g, which is not supported now\n");
return false;
}
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int gemm_k_per_block = tunable->gemm_k_per_block;
Expand Down Expand Up @@ -456,9 +471,14 @@ class igemm_bwd_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int gemm_k_per_block = tunable->gemm_k_per_block;
Expand Down Expand Up @@ -622,7 +642,7 @@ class igemm_bwd_gtc_t {
hipEvent_t stop;
hipEventCreate(&start);
hipEventCreate(&stop);
HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1,
HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, splits, 1,
u_block_size, 1, 1, 0, 0, NULL,
(void **)&config, start, stop));
hipEventSynchronize(stop);
Expand Down Expand Up @@ -670,7 +690,7 @@ class igemm_bwd_gtc_t {
hipEventCreate(&start);
hipEventCreate(&stop);
// for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config, start, stop));
hipEventSynchronize(stop);
Expand All @@ -680,7 +700,7 @@ class igemm_bwd_gtc_t {
#else
gpu_timer_t timer(NULL);
timer.start();
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config));
timer.stop();
Expand All @@ -703,7 +723,7 @@ class igemm_bwd_gtc_t {
hipEvent_t stop;
hipEventCreate(&start);
hipEventCreate(&stop);
HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, 1, 1,
HIP_CALL(hipHccModuleLaunchKernel(upsampling_clear_kernel_func, u_grid_size * u_block_size, splits, 1,
u_block_size, 1, 1, 0, 0, NULL,
(void **)&config, start, stop));
hipEventSynchronize(stop);
Expand Down Expand Up @@ -732,7 +752,7 @@ class igemm_bwd_gtc_t {
hipEventCreate(&start);
hipEventCreate(&stop);
// for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config, start, stop));
hipEventSynchronize(stop);
Expand All @@ -742,7 +762,7 @@ class igemm_bwd_gtc_t {
#else
gpu_timer_t timer(NULL);
timer.start();
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config));
timer.stop();
Expand Down
62 changes: 10 additions & 52 deletions driver/igemm_fwd_gtc_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ class igemm_fwd_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

int splits = split_batch_size(arg, tunable);
int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
Expand All @@ -160,54 +162,6 @@ class igemm_fwd_gtc_t {
return grid_size;
}

// this is to support big tensor > 4G. need to decide how many splits needed
// return the number of splits
int split_batch_size(const args_t *arg, const igemm_gtc_tunable_t *tunable)
{
    // pull the convolution problem description out of the argument list
    const int hi = arg->get_int("in_h");
    const int wi = arg->get_int("in_w");
    const int n  = arg->get_int("batchsize");
    const int k  = arg->get_int("out_channels");
    const int c  = arg->get_int("in_channels");

    const int stride_h   = arg->get_int("conv_stride_h");
    const int stride_w   = arg->get_int("conv_stride_w");
    const int dilation_h = arg->get_int("dilation_h");
    const int dilation_w = arg->get_int("dilation_w");
    const int pad_h      = arg->get_int("pad_h");
    const int pad_w      = arg->get_int("pad_w");
    const int y          = arg->get_int("fil_h");
    const int x          = arg->get_int("fil_w");
    const int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
    const int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);

    const int data_byte = utility_string_to_data_byte(tunable->precision);
    // per-batch footprint of the input and output images, in bytes
    const size_t in_bytes  = static_cast<size_t>(c) * hi * wi * data_byte;
    const size_t out_bytes = static_cast<size_t>(k) * ho * wo * data_byte;
    const size_t four_g    = 0xffffffffUL;
    // a single image already at/above 4G cannot be fixed by batch splitting
    if(in_bytes >= four_g || out_bytes >= four_g)
        return 0;

    // the larger of the two tensors decides how many batches fit under 4G
    const size_t per_image = out_bytes > in_bytes ? out_bytes : in_bytes;
    size_t chunk = four_g / per_image;
    assert(chunk != 0);

    // shrink chunk until it evenly divides n; chunk == 1 always does, so the
    // loop is guaranteed to terminate with a valid split
    for(; chunk >= 1; --chunk){
        if(n % chunk == 0)
            break;
    }

    assert(chunk * per_image < four_g && n % chunk == 0);
    return n / chunk;
}

bool tunable_is_valid(const args_t *arg,
const igemm_gtc_tunable_t *tunable)
{
Expand All @@ -228,15 +182,17 @@ class igemm_fwd_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

assert(c % group == 0 && k % group == 0);

int splits = split_batch_size(arg, tunable);
int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
if(splits == 0){
printf("image size (c*h*w) is bigger than 4g, which is not supported now\n");
printf("image size (c*h*w or k*h*w) is bigger than 4g, which is not supported now\n");
return false;
}
n = n/splits; // split batch size here


int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
Expand Down Expand Up @@ -328,10 +284,12 @@ class igemm_fwd_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

assert(c % group == 0 && k % group == 0);

int splits = split_batch_size(arg, tunable);
int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
Expand Down
31 changes: 31 additions & 0 deletions driver/igemm_gtc_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,4 +310,35 @@ igemm_gtc_encode_kernel_name(const igemm_gtc_tunable_t *tunable) {
return kernel_name;
}

// this is to support big tensor > 4G. need to decide how many splits needed
// return the number of splits, valid for nchw, nhwc, 2d/3d conv.
// returns 0 when a single image (c*d*h*w or k*d*h*w, in bytes) itself reaches
// 4G — such a case cannot be made to fit by splitting the batch dimension.
// NOTE: 'inline' is required here — this definition lives in a header and
// would otherwise break the ODR (multiple-definition link error) as soon as
// the header is included from more than one translation unit.
inline int igemm_split_batch_size(int n, int wi, int hi, int di, int c, int k, int wo, int ho, int do_, int data_byte)
{
    size_t image_size_input  = static_cast<size_t>(c) * di * hi * wi * data_byte;
    size_t image_size_output = static_cast<size_t>(k) * do_ * ho * wo * data_byte;
    size_t size_4g = 0xffffffffUL;
    if(image_size_input >= size_4g || image_size_output >= size_4g)
        return 0;

    // the bigger of the two tensors limits how many batches fit under 4G
    size_t image_size = image_size_input >= image_size_output ? image_size_input : image_size_output;
    size_t splited_n  = size_4g / image_size;

    // round up splits, we must match
    // 1. splited_n * image_size < size_4g
    // 2. n % splited_n == 0
    assert(splited_n != 0);

    // cast avoids a signed/unsigned comparison; n is a batch count and
    // expected to be non-negative
    if(splited_n >= static_cast<size_t>(n))
        return 1; // speed up following while loop

    // walk down to the largest chunk that evenly divides n (1 always does)
    while(splited_n >= 1){
        if(static_cast<size_t>(n) % splited_n == 0)
            break;
        splited_n--;
    }

    assert(splited_n * image_size < size_4g && static_cast<size_t>(n) % splited_n == 0);
    return n / static_cast<int>(splited_n);
}

#endif
48 changes: 44 additions & 4 deletions driver/igemm_wrw_gtc_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ class igemm_wrw_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n is never used?


int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
Expand Down Expand Up @@ -245,8 +250,16 @@ class igemm_wrw_gtc_t {
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int b = tunable->nxe == 0 ? (ho * wo) : ((ho * wo + tunable->nxb - 1) / tunable->nxb) * tunable->nxb; // pad to nxb modulo when nxe != 0
int data_byte = utility_string_to_data_byte(tunable->precision);
assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
if(splits == 0){
printf("image size (c*h*w or k*h*w) is bigger than 4g, which is not supported now\n");
return false;
}
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int gemm_k_per_block = tunable->gemm_k_per_block;
Expand All @@ -266,6 +279,12 @@ class igemm_wrw_gtc_t {
int nxe = tunable->nxe == 0 ? 1 : tunable->nxe;
bool unit_conv = (x==1)&&(y==1)&&(stride_h==1)&&(stride_w==1)&&(dilation_h==1)&&(dilation_w==1)&&(pad_h==0)&&(pad_w==0);

if(splits > 1 && gemm_k_global_split == 0)
{
// large tensor can only used for gkgs kernel
return false;
}

if(((c / group) % (gemm_n_per_block / nxe) != 0) || (((x * y) % nxe) != 0))
{
return false;
Expand Down Expand Up @@ -339,8 +358,13 @@ class igemm_wrw_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);
assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int b = tunable->nxe == 0 ? (ho * wo) : ((ho * wo + tunable->nxb - 1) / tunable->nxb) * tunable->nxb;

int gemm_m_per_block = tunable->gemm_m_per_block;
Expand Down Expand Up @@ -392,7 +416,8 @@ class igemm_wrw_gtc_t {
static int if_gemm_k_global_split(const args_t *arg,
const int gemm_m_per_block,
const int gemm_n_per_block,
const int gemm_k_per_block)
const int gemm_k_per_block,
const int data_byte)
{
int gemm_k_global_split = 0;
int hi = arg->get_int("in_h");
Expand All @@ -414,6 +439,10 @@ class igemm_wrw_gtc_t {
int group = arg->get_int("group_count");
assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

// int b = tunable->nxe == 0 ? (ho * wo) : ((ho * wo + tunable->nxb - 1) / tunable->nxb) * tunable->nxb;
int gemm_m = k / group;
int gemm_n = (c / group) * y * x;
Expand Down Expand Up @@ -472,8 +501,13 @@ class igemm_wrw_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunables[0].precision);
assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int gemm_m_per_block = 0;
int gemm_n_per_block = 0;
int gemm_k_per_block = 0;
Expand Down Expand Up @@ -541,7 +575,8 @@ class igemm_wrw_gtc_t {
gemm_k_global_split = if_gemm_k_global_split(arg,
gemm_m_per_block,
gemm_n_per_block,
gemm_k_per_block);
gemm_k_per_block,
data_byte);

nxb = 1;
nxe = 1;
Expand Down Expand Up @@ -660,8 +695,13 @@ class igemm_wrw_gtc_t {
int ho = conv_out_size(hi, pad_h, dilation_h, y, stride_h);
int wo = conv_out_size(wi, pad_w, dilation_w, x, stride_w);
int group = arg->get_int("group_count");
int data_byte = utility_string_to_data_byte(tunable->precision);
assert(c % group == 0 && k % group == 0);

int splits = igemm_split_batch_size(n, wi, hi, 1, c, k, wo, ho, 1, data_byte);
assert(splits != 0);
n = n/splits; // split batch size here

int gemm_m_per_block = tunable->gemm_m_per_block;
int gemm_n_per_block = tunable->gemm_n_per_block;
int gemm_k_per_block = tunable->gemm_k_per_block;
Expand Down Expand Up @@ -729,7 +769,7 @@ class igemm_wrw_gtc_t {
hipEventCreate(&stop);

// for hipHccModuleLaunchKernel/hipExtModuleLaunchKernel, the grid_size is in unit of workitem
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, 1, 1,
HIP_CALL(hipHccModuleLaunchKernel(kernel_func, grid_size * block_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config, start, stop));

Expand All @@ -741,7 +781,7 @@ class igemm_wrw_gtc_t {
gpu_timer_t timer(NULL);
timer.start();

HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, 1, 1,
HIP_CALL(hipModuleLaunchKernel(kernel_func, grid_size, splits, 1,
block_size, 1, 1, 0, 0, NULL,
(void **)&config));

Expand Down
Loading