[BN] Enable NHWC in OCL #3399

Open · wants to merge 13 commits into base: develop
69 changes: 38 additions & 31 deletions driver/bn_driver.hpp
@@ -195,10 +195,10 @@ int BatchNormDriver<TInput, Tref, TAcc, TScaleBias, TOut>::GetandSetData()
     SetBNParametersFromCmdLineArgs();
 
     in.AllocOnHost(tensor<TInput>{bn_layout, in_len});
-    for(size_t i = 0; i < in.GetVector().size(); i++)
-    {
-        in.GetVector()[i] = prng::gen_canonical<TInput>();
-    }
+    // 0.0 to 2.0 (since unsigned)
+    in.GetTensor().generate([](auto...) {
+        return prng::gen_descreet_unsigned<TInput>(2e-3 /*scale*/, 1000 /*range*/);
+    });
 
     auto derivedBnDesc = miopen::TensorDescriptor{};
     miopen::DeriveBNTensorDescriptor(derivedBnDesc, in.GetTensor().desc, bn_mode);
@@ -208,20 +208,25 @@ int BatchNormDriver<TInput, Tref, TAcc, TScaleBias, TOut>::GetandSetData()
         out.AllocOnHost(tensor<TInput>{bn_layout, in_len});
         scale.AllocOnHost(tensor<TScaleBias>{bn_layout, derivedBnDesc.GetLengths()});
         bias.AllocOnHost(tensor<TScaleBias>{bn_layout, derivedBnDesc.GetLengths()});
 
-        for(int i = 0; i < scale.GetVector().size(); i++)
-        {
-            scale.GetVector()[i] = prng::gen_canonical<TInput>();
-            bias.GetVector()[i] = prng::gen_canonical<TInput>();
-        }
+        // -2.0 to 2.0
+        scale.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign<TScaleBias>(2e-3 /*scale*/, 1000 /*range*/);
+        });
+        bias.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign<TScaleBias>(2e-3 /*scale*/, 1000 /*range*/);
+        });
     }
     if(isFwdInfer)
     {
         estMean.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
         estVariance.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
 
-        auto gen_value_emean = [](auto...) { return prng::gen_descreet_unsigned<TAcc>(1e-2, 100); };
-        estMean.InitHostData(estMean.GetTensor().desc.GetElementSize(), true, gen_value_emean);
+        // -2.0 to 2.0
+        estMean.InitHostData(estMean.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign<TAcc>(2e-3 /*scale*/, 1000 /*range*/);
+        });
         estVariance.GetTensor().generate(
             [](auto...) { return static_cast<TAcc>(2e-3 * (prng::gen_0_to_B(1000) + 1)); });
     }
     else if(isFwdTrain)
     {
@@ -230,45 +235,47 @@ int BatchNormDriver<TInput, Tref, TAcc, TScaleBias, TOut>::GetandSetData()
         runMean.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
         runVariance.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
 
-        for(int i = 0; i < runVariance.GetVector().size(); i++)
-        {
-            runMean.GetVector()[i] = prng::gen_canonical<TAcc>();
-            runVariance.GetVector()[i] = prng::gen_canonical<TAcc>();
-        }
+        // -2.0 to 2.0
+        runMean.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign<TAcc>(2e-3 /*scale*/, 1000 /*range*/);
+        });
+        runVariance.GetTensor().generate([](auto...) {
+            return prng::gen_descreet_uniform_sign<TAcc>(2e-3 /*scale*/, 1000 /*range*/);
+        });
     }
     else if(isBwd)
     {
         out_bwd.AllocOnHost(tensor<TOut>{bn_layout, in_len});
 
         bnScale.AllocOnHost(tensor<TScaleBias>{bn_layout, derivedBnDesc.GetLengths()});
         dy.AllocOnHost(tensor<TOut>{bn_layout, in_len});
 
-        auto gen_var_bwd = [](auto...) {
-            return static_cast<TOut>(1e-2 * (prng::gen_0_to_B(100) + 1));
-        };
-
-        dy.InitHostData(dy.GetTensor().desc.GetElementSize(), true, gen_var_bwd);
+        // -2.0 to 2.0
+        dy.InitHostData(dy.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign<TOut>(2e-3, 1000);
+        });
 
         dScale.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
         dBias.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
         savedMean.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
         savedInvVar.AllocOnHost(tensor<TAcc>{bn_layout, derivedBnDesc.GetLengths()});
 
-        auto gen_value = [](auto...) { return prng::gen_descreet_unsigned<TScaleBias>(1e-2, 100); };
-        bnScale.InitHostData(bnScale.GetTensor().desc.GetElementSize(), true, gen_value);
-
-        auto gen_in_var = [](auto...) {
-            return static_cast<TAcc>(1e-2 * (prng::gen_0_to_B(100) + 1));
-        };
-        savedMean.InitHostData(savedMean.GetTensor().desc.GetElementSize(), true, gen_in_var);
-        savedInvVar.InitHostData(savedInvVar.GetTensor().desc.GetElementSize(), true, gen_in_var);
+        auto gen_value_bnScale = [](auto...) {
+            return prng::gen_descreet_uniform_sign<TScaleBias>(2e-3, 1000);
+        };
+        bnScale.InitHostData(bnScale.GetTensor().desc.GetElementSize(), true, gen_value_bnScale);
+        // -2.0 to 2.0
+        savedMean.InitHostData(savedMean.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign<TAcc>(2e-3, 1000);
+        });
+        savedInvVar.InitHostData(savedInvVar.GetTensor().desc.GetElementSize(), true, [](auto...) {
+            return prng::gen_descreet_uniform_sign<TAcc>(2e-3, 1000);
+        });
     }
     else
     {
        std::cout << "\nUnknown batch norm state!\n";
        exit(EXIT_FAILURE);
     }
 
     return miopenStatusSuccess;
 }
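
Note: the driver now seeds host tensors with MIOpen's discrete PRNG helpers rather than prng::gen_canonical. A rough sketch of the assumed semantics (inferred from the range comments above; the real prng::gen_descreet_* helpers live in MIOpen's test utilities and may differ): gen_descreet_unsigned<T>(scale, range) draws an integer in [0, range) and scales it, while gen_descreet_uniform_sign<T>(scale, range) also randomizes the sign.

    // Minimal standalone sketch of the assumed generator semantics
    // (illustrative only; not the actual MIOpen implementation).
    #include <cstdlib>

    template <typename T>
    T gen_descreet_unsigned_sketch(double scale, int range)
    {
        // integer in [0, range), scaled: values in [0, scale * (range - 1)]
        return static_cast<T>(scale * (std::rand() % range));
    }

    template <typename T>
    T gen_descreet_uniform_sign_sketch(double scale, int range)
    {
        const int v    = std::rand() % range;
        const int sign = (std::rand() % 2 == 0) ? 1 : -1;
        // values in roughly [-scale * range, scale * range]
        return static_cast<T>(sign * scale * v);
    }
    // e.g. gen_descreet_uniform_sign_sketch<float>(2e-3, 1000) => about [-2.0, 2.0]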

4 changes: 3 additions & 1 deletion src/kernels/MIOpenBatchNormFwdInferPerAct.cl
@@ -43,6 +43,8 @@ MIOpenBatchNormFwdInferPerActivationEst(const __global _FLOAT* in,
                                         const __global _FLOAT_PREC* __restrict bias,
                                         double epsilon,
                                         unsigned int batchSize,
+                                        unsigned int cLen,
+                                        unsigned int cStride,
                                         unsigned int imageDims,
                                         unsigned int batchStride)
 {
@@ -58,7 +60,7 @@ MIOpenBatchNormFwdInferPerActivationEst(const __global _FLOAT* in,
 
     for(int img_offset = ygid; img_offset < imageDims; img_offset += yglb_sz)
     {
-        adjIndex = (grpid * imageDims) + img_offset;
+        adjIndex = (grpid * cStride) + img_offset * cLen;
         mean = estimatedMean[adjIndex];
         variance = estimatedVariance[adjIndex];
         invVariance = rsqrt(fabs(variance + epsilon));
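Note: the rewritten index makes the kernel layout-generic. For NCHW the host passes cLen = 1 and cStride = H*W, which reduces to the old (grpid * imageDims) + img_offset; for NHWC it passes cLen = C and cStride = 1, so the same expression walks the channel-interleaved buffer. A small host-side sanity check of that equivalence (hypothetical standalone C++, not part of the PR):

    #include <cassert>
    #include <cstddef>

    // Index exactly as the kernel computes it:
    // n * batchStride + grpid * cStride + img_offset * cLen
    std::size_t kernel_index(std::size_t n, std::size_t grpid, std::size_t img_offset,
                             std::size_t cLen, std::size_t cStride, std::size_t batchStride)
    {
        return n * batchStride + grpid * cStride + img_offset * cLen;
    }

    int main()
    {
        const std::size_t N = 2, C = 3, HW = 4; // toy shape, H*W flattened
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t c = 0; c < C; ++c)
                for(std::size_t hw = 0; hw < HW; ++hw)
                {
                    // NCHW: cLen = 1, cStride = HW, batchStride = C * HW
                    assert(kernel_index(n, c, hw, 1, HW, C * HW) == n * C * HW + c * HW + hw);
                    // NHWC: cLen = C, cStride = 1, batchStride = HW * C
                    assert(kernel_index(n, c, hw, C, 1, HW * C) == n * HW * C + hw * C + c);
                }
    }
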
4 changes: 3 additions & 1 deletion src/kernels/MIOpenBatchNormFwdInferSpatial.cl
@@ -43,6 +43,8 @@ MIOpenBatchNormFwdInferSpatialEst(const __global _FLOAT* __restrict in, /* x inp
                                   const __global _FLOAT_PREC* __restrict bias,
                                   double epsilon,
                                   unsigned int batchSize,
+                                  unsigned int cLen,
+                                  unsigned int cStride,
                                   unsigned int imageDims,
                                   unsigned int batchStride)
 {
@@ -66,7 +68,7 @@ MIOpenBatchNormFwdInferSpatialEst(const __global _FLOAT* __restrict in, /* x inp
     {
         for(int n = 0; n < batchSize; n++)
         {
-            index = (n * batchStride) + (xgid * imageDims) + idx;
+            index = (n * batchStride) + (xgid * cStride) + idx * cLen;
             inhat = (FLOAT2FLOATPREC(*(in + index)) - mean) * invVariance;
             out[index] = FLOATPREC2FLOAT(mad(pscale, inhat, pbias));
         }
9 changes: 4 additions & 5 deletions src/ocl/batchnormocl.cpp
@@ -152,9 +152,9 @@ void BatchNormForwardTraining(Handle& handle,
     }();
 
     const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdTrainingSpatialSingle,
-                                                 // solver::batchnorm::BnCKFwdTraining,
                                                  solver::batchnorm::BnFwdTrainingSpatialMultiple,
                                                  solver::batchnorm::BnFwdTrainingPerActivation>{};
+                                                 // solver::batchnorm::BnCKFwdTraining>{};
 
     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
 
@@ -250,9 +250,8 @@ void BatchNormForwardInference(Handle& handle,
     }();
 
     const auto algo = AlgorithmName{"miopenBatchNormalizationForwardInference"};
-    const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdInference
-                                                 // solver::batchnorm::BnCKFwdInference
-                                                 >{};
+    const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdInference>{};
+    // solver::batchnorm::BnCKFwdInference>{};
 
     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
 }
@@ -395,9 +394,9 @@ void BatchNormBackward(Handle& handle,
     }();
 
     const auto solvers = solver::SolverContainer<solver::batchnorm::BnBwdTrainingSpatialSingle,
-                                                 // solver::batchnorm::BnCKBwdBackward,
                                                  solver::batchnorm::BnBwdTrainingSpatialMultiple,
                                                  solver::batchnorm::BnBwdTrainingPerActivation>{};
+                                                 // solver::batchnorm::BnCKBwdBackward>{};
 
     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
 
24 changes: 24 additions & 0 deletions src/solver/batchnorm/backward_spatial_multiple.cpp
@@ -38,9 +38,33 @@ namespace solver {
 
 namespace batchnorm {
 
+bool BNBwdIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
+{
+    size_t n, c, h, w;
+    std::tie(n, c, h, w) = tien<4>(problem.GetXDesc().GetLengths());
+
+    size_t in_cstride = h * w;
+    size_t in_nhw     = n * in_cstride;
+
+    if((in_nhw >= static_cast<size_t>(32 * 1024 * 1024) || in_cstride <= 1024) && in_cstride > 512)
+    {
+        return true;
+    }
+    else
+        return false;
+}
+
 bool BnBwdTrainingSpatialMultiple::IsApplicable(
     const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const
 {
+    if(!problem.IsLayoutNCHW())
+        return false;
+    // NCHW is applicable for variant == 2 only
+    if(!BNBwdIsCaseVariant2(problem))
+    {
+        return false;
+    }
+
     if(problem.GetDirection() != miopen::batchnorm::Direction::Backward ||
        problem.GetMode() != miopenBNSpatial)
         return false;
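Note: the new gate restricts this solver to the shapes its variant-2 kernels target. A standalone restatement of the predicate with a few probe shapes (illustrative only, not MIOpen code):

    #include <cstddef>
    #include <iostream>

    // Same condition as BNBwdIsCaseVariant2 above, restated over (n, h, w).
    bool is_bwd_variant2(std::size_t n, std::size_t h, std::size_t w)
    {
        const std::size_t in_cstride = h * w;
        const std::size_t in_nhw     = n * in_cstride;
        return (in_nhw >= std::size_t(32) * 1024 * 1024 || in_cstride <= 1024) && in_cstride > 512;
    }

    int main()
    {
        std::cout << is_bwd_variant2(64, 28, 28) << '\n';   // 1: 784 lies in (512, 1024]
        std::cout << is_bwd_variant2(64, 14, 14) << '\n';   // 0: 196 <= 512
        std::cout << is_bwd_variant2(8192, 64, 64) << '\n'; // 1: n*h*w >= 32M and 4096 > 512
        std::cout << is_bwd_variant2(4, 64, 64) << '\n';    // 0: 4096 > 1024 but n*h*w < 32M
    }
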
42 changes: 30 additions & 12 deletions src/solver/batchnorm/forward_inference.cpp
@@ -41,8 +41,6 @@ namespace batchnorm {
 bool BnFwdInference::IsApplicable(const ExecutionContext&,
                                   const miopen::batchnorm::ProblemDescription& bn_problem) const
 {
-    if(bn_problem.IsLayoutNHWC())
-        return false;
     if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardInference)
         return false;
     if(!(bn_problem.IsFp32() or bn_problem.IsFp16() or bn_problem.IsBFp16()))
@@ -149,16 +147,36 @@ ConvSolution BnFwdInference::GetSolution(const ExecutionContext& context,
             unsigned int in_nstride_ = c_ * h_ * w_;
             unsigned int in_cstride_ = h_ * w_;
 
-            kernel(params.x,
-                   params.y,
-                   params.estimatedMean,
-                   params.estimatedVariance,
-                   params.bnScale,
-                   params.bnBias,
-                   params.epsilon,
-                   n_,
-                   in_cstride_,
-                   in_nstride_);
+            if(params.xDesc->GetLayout_t() == miopenTensorNHWC)
+            {
+                kernel(params.x,
+                       params.y,
+                       params.estimatedMean,
+                       params.estimatedVariance,
+                       params.bnScale,
+                       params.bnBias,
+                       params.epsilon,
+                       n_,
+                       c_, // nhwc: cLen = C
+                       1,  // nhwc: cStride = 1
+                       in_cstride_,
+                       in_nstride_);
+            }
+            else
+            {
+                kernel(params.x,
+                       params.y,
+                       params.estimatedMean,
+                       params.estimatedVariance,
+                       params.bnScale,
+                       params.bnBias,
+                       params.epsilon,
+                       n_,
+                       1,       // nchw: cLen = 1
+                       h_ * w_, // nchw: cStride = H * W
+                       in_cstride_,
+                       in_nstride_);
+            }
         };
     };

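Note: with the NHWC rejection removed, BnFwdInference now serves NHWC by handing the kernels (cLen, cStride) = (C, 1), as shown above. For intuition, a plain host-side reference of what the spatial-estimate kernel computes over an NHWC buffer (a hypothetical helper for illustration, not MIOpen code):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Spatial BN inference on an NHWC buffer: one mean/variance/scale/bias per
    // channel c, applied at every (n, h, w) position.
    void bn_infer_spatial_nhwc(const std::vector<float>& x, std::vector<float>& y,
                               const std::vector<float>& mean, const std::vector<float>& var,
                               const std::vector<float>& scale, const std::vector<float>& bias,
                               std::size_t N, std::size_t HW, std::size_t C, double eps)
    {
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t hw = 0; hw < HW; ++hw)
                for(std::size_t c = 0; c < C; ++c)
                {
                    // NHWC flat index: batchStride = HW * C, cStride = 1, cLen = C
                    const std::size_t idx = n * HW * C + hw * C + c;
                    const float invVar =
                        1.0f / std::sqrt(std::fabs(static_cast<float>(var[c] + eps)));
                    y[idx] = scale[c] * (x[idx] - mean[c]) * invVar + bias[c];
                }
    }
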
31 changes: 31 additions & 0 deletions src/solver/batchnorm/forward_spatial_multiple.cpp
@@ -40,9 +40,40 @@ namespace solver {
 
 namespace batchnorm {
 
+bool BNFwdTrainIsCaseVariant2(const miopen::batchnorm::ProblemDescription& problem)
+{
+    const auto& xDesc = problem.GetXDesc();
+    size_t n, c, h, w;
+    std::tie(n, c, h, w) = tien<4>(xDesc.GetLengths());
+    size_t in_cstride = h * w;
+    size_t in_nhw     = n * in_cstride;
+    bool bfp32parm    = xDesc.GetType() == miopenFloat;
+    bool bfpmixparm   = (xDesc.GetType() == miopenHalf || xDesc.GetType() == miopenBFloat16) &&
+                        problem.GetBnScale().GetType() == miopenFloat;
+
+    // NCHW is applicable for variant == 2 only; these numbers come from
+    // BnFwdTrainingSpatialMultiple::GetSolution in forward_spatial_multiple.cpp.
+    if((n >= 3 && in_cstride > 512 && (in_nhw >= 33554432 || in_cstride <= 1024) &&
+        ((n < 256) || (in_cstride <= 60) || !bfpmixparm) && (!bfpmixparm || in_cstride <= 512)) ||
+       (n <= 768 || in_cstride <= 150 || !bfp32parm))
+    {
+        return true;
+    }
+    else
+        return false;
+}
+
 bool BnFwdTrainingSpatialMultiple::IsApplicable(
     const ExecutionContext& context, const miopen::batchnorm::ProblemDescription& problem) const
 {
+    // If the problem is NCHW, accept it only when it is a variant-2 case
+    // (for all data types); GetSolution must not change the variant.
+    if(!BNFwdTrainIsCaseVariant2(problem))
+    {
+        return false;
+    }
+
     if(problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining ||
        problem.GetMode() != miopenBNSpatial)
         return false;
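Note: the forward-training gate follows the same pattern. Restating it standalone (illustrative only) makes the clause structure visible: the second disjunct already admits every non-fp32 problem (!bfp32parm) and every fp32 problem with n <= 768 or in_cstride <= 150, so only large fp32 shapes actually reach the variant-2 shape test:

    #include <cstddef>
    #include <iostream>

    // Restatement of BNFwdTrainIsCaseVariant2 above over (n, h, w, fp32, mixed).
    bool is_fwd_variant2(std::size_t n, std::size_t h, std::size_t w,
                         bool bfp32parm, bool bfpmixparm)
    {
        const std::size_t in_cstride = h * w;
        const std::size_t in_nhw     = n * in_cstride;
        return (n >= 3 && in_cstride > 512 && (in_nhw >= 33554432 || in_cstride <= 1024) &&
                ((n < 256) || (in_cstride <= 60) || !bfpmixparm) &&
                (!bfpmixparm || in_cstride <= 512)) ||
               (n <= 768 || in_cstride <= 150 || !bfp32parm);
    }

    int main()
    {
        std::cout << is_fwd_variant2(1024, 16, 16, true, false) << '\n'; // 0: large fp32 misses both clauses
        std::cout << is_fwd_variant2(1024, 16, 16, false, true) << '\n'; // 1: non-fp32 passes the second disjunct
        std::cout << is_fwd_variant2(64, 28, 28, true, false) << '\n';   // 1: fp32 with n <= 768
    }
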
8 changes: 4 additions & 4 deletions test/gtest/bn_bwd.cpp
@@ -95,7 +95,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_BWD_Large_FP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 
@@ -110,22 +110,22 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_BWD_Large_BFP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 
 // fp32
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_BWD_Small_FP32,
                          testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1})),
                          TestNameGenerator());
 
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_BWD_Large_FP32,
                          testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 // // fp64
8 changes: 4 additions & 4 deletions test/gtest/bn_fwd_train.cpp
@@ -101,7 +101,7 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_FWD_Train_Large_FP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
                          TestNameGenerator());
 
@@ -116,22 +116,22 @@ INSTANTIATE_TEST_SUITE_P(Smoke,
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_OCL_FWD_Train_Large_BFP16,
                          testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1, testBNAPIV2})),
                          TestNameGenerator());
 
 // // fp32
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_FWD_Train_Small_FP32,
                          testing::Combine(testing::ValuesIn(NetworkSmall<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV1})),
                          TestNameGenerator());
 
 INSTANTIATE_TEST_SUITE_P(Smoke,
                          GPU_BN_FWD_Train_Large_FP32,
                          testing::Combine(testing::ValuesIn(NetworkLarge<BNTestCase>()),
-                                          testing::ValuesIn({miopenTensorNCHW}),
+                                          testing::ValuesIn({miopenTensorNCHW, miopenTensorNHWC}),
                                           testing::ValuesIn({testBNAPIV2})),
                          TestNameGenerator());
 // // fp64