Skip to content

Commit ee1fe00

Browse files
committed
Add normalization ops
1 parent c67f5b4 commit ee1fe00

File tree

3 files changed

+149
-7
lines changed

3 files changed

+149
-7
lines changed

XUSGMachineLearning/MachineLearning/XUSGMachineLearning.h

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ namespace XUSG
109109
enum class TensorFlag
110110
{
111111
NONE,
112-
MANAGED,
112+
MANAGED
113113
};
114114

115115
DEFINE_ENUM_FLAG_OPERATORS(TensorFlag);
@@ -119,7 +119,7 @@ namespace XUSG
119119
NONE,
120120
ALLOW_HALF_PRECISION_COMPUTATION,
121121
DISABLE_META_COMMANDS,
122-
DESCRIPTORS_VOLATILE,
122+
DESCRIPTORS_VOLATILE
123123
};
124124

125125
DEFINE_ENUM_FLAG_OPERATORS(ExecutionFlag);
@@ -521,6 +521,54 @@ namespace XUSG
521521
uint32_t K;
522522
};
523523

524+
struct BatchNormalization
525+
526+
{
527+
const Tensor* pInput;
528+
const Tensor* pMean;
529+
const Tensor* pVariance;
530+
const Tensor* pScale;
531+
const Tensor* pBias;
532+
const Tensor* pOutput;
533+
bool Spatial;
534+
float Epsilon;
535+
OperatorType FusedActivationType;
536+
const void* pFusedActivation;
537+
};
538+
539+
struct MeanVarianceNormalization
540+
{
541+
const Tensor* pInput;
542+
const Tensor* pScale;
543+
const Tensor* pBias;
544+
const Tensor* pOutput;
545+
bool CrossChannel;
546+
bool NormalizeVariance;
547+
float Epsilon;
548+
OperatorType FusedActivationType;
549+
const void* pFusedActivation;
550+
};
551+
552+
struct LocalResponseNormalization
553+
{
554+
const Tensor* pInput;
555+
const Tensor* pOutput;
556+
bool CrossChannel;
557+
uint32_t LocalSize;
558+
float Alpha;
559+
float Beta;
560+
float Bias;
561+
};
562+
563+
struct LPNormalization
564+
{
565+
const Tensor* pInput;
566+
const Tensor* pOutput;
567+
uint32_t Axis;
568+
float Epsilon;
569+
uint32_t P;
570+
};
571+
524572
//--------------------------------------------------------------------------------------
525573
// Device
526574
//--------------------------------------------------------------------------------------

XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,100 @@ void ML::GetDMLTypedOperator(vector<uint8_t>& dmlTypedOpDesc, OperatorType type,
692692
pDMLDesc->K = desc.K;
693693
};
694694

695+
static const auto getDMLBatchNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
696+
{
697+
const auto& desc = *static_cast<const BatchNormalization*>(pOpDesc);
698+
699+
vector<uint8_t> typedFused(0);
700+
if (desc.pFusedActivation) GetDMLTypedOperator(typedFused, desc.FusedActivationType, desc.pFusedActivation);
701+
702+
dmlTypedOpDesc.resize(sizeof(DML_BATCH_NORMALIZATION_OPERATOR_DESC) +
703+
(desc.pFusedActivation ? sizeof(DML_OPERATOR_DESC) + typedFused.size() : 0));
704+
const auto pDMLDesc = reinterpret_cast<DML_BATCH_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());
705+
const auto pDMLFused = desc.pFusedActivation ? reinterpret_cast<DML_OPERATOR_DESC*>(
706+
&dmlTypedOpDesc[sizeof(DML_BATCH_NORMALIZATION_OPERATOR_DESC)]) : nullptr;
707+
708+
pDMLDesc->InputTensor = desc.pInput ? static_cast<const DML_TENSOR_DESC*>(desc.pInput->GetHandle()) : nullptr;
709+
pDMLDesc->MeanTensor = desc.pMean ? static_cast<const DML_TENSOR_DESC*>(desc.pMean->GetHandle()) : nullptr;
710+
pDMLDesc->VarianceTensor = desc.pVariance ? static_cast<const DML_TENSOR_DESC*>(desc.pVariance->GetHandle()) : nullptr;
711+
pDMLDesc->ScaleTensor = desc.pScale ? static_cast<const DML_TENSOR_DESC*>(desc.pScale->GetHandle()) : nullptr;
712+
pDMLDesc->BiasTensor = desc.pBias ? static_cast<const DML_TENSOR_DESC*>(desc.pBias->GetHandle()) : nullptr;
713+
pDMLDesc->OutputTensor = desc.pOutput ? static_cast<const DML_TENSOR_DESC*>(desc.pOutput->GetHandle()) : nullptr;
714+
pDMLDesc->Spatial = desc.Spatial;
715+
pDMLDesc->Epsilon = desc.Epsilon;
716+
pDMLDesc->FusedActivation = pDMLFused;
717+
718+
if (pDMLFused)
719+
{
720+
assert(desc.pFusedActivation);
721+
const auto offset = sizeof(DML_CONVOLUTION_OPERATOR_DESC) + sizeof(DML_OPERATOR_DESC);
722+
pDMLFused->Type = GetDMLOpteratorType(desc.FusedActivationType);
723+
pDMLFused->Desc = &dmlTypedOpDesc[offset];
724+
memcpy(&dmlTypedOpDesc[offset], typedFused.data(), typedFused.size());
725+
}
726+
};
727+
728+
static const auto getDMLMeanVarianceNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
729+
{
730+
const auto& desc = *static_cast<const MeanVarianceNormalization*>(pOpDesc);
731+
732+
vector<uint8_t> typedFused(0);
733+
if (desc.pFusedActivation) GetDMLTypedOperator(typedFused, desc.FusedActivationType, desc.pFusedActivation);
734+
735+
dmlTypedOpDesc.resize(sizeof(DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC) +
736+
(desc.pFusedActivation ? sizeof(DML_OPERATOR_DESC) + typedFused.size() : 0));
737+
const auto pDMLDesc = reinterpret_cast<DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());
738+
const auto pDMLFused = desc.pFusedActivation ? reinterpret_cast<DML_OPERATOR_DESC*>(
739+
&dmlTypedOpDesc[sizeof(DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC)]) : nullptr;
740+
741+
pDMLDesc->InputTensor = desc.pInput ? static_cast<const DML_TENSOR_DESC*>(desc.pInput->GetHandle()) : nullptr;
742+
pDMLDesc->ScaleTensor = desc.pScale ? static_cast<const DML_TENSOR_DESC*>(desc.pScale->GetHandle()) : nullptr;
743+
pDMLDesc->BiasTensor = desc.pBias ? static_cast<const DML_TENSOR_DESC*>(desc.pBias->GetHandle()) : nullptr;
744+
pDMLDesc->OutputTensor = desc.pOutput ? static_cast<const DML_TENSOR_DESC*>(desc.pOutput->GetHandle()) : nullptr;
745+
pDMLDesc->CrossChannel = desc.CrossChannel;
746+
pDMLDesc->NormalizeVariance = desc.NormalizeVariance;
747+
pDMLDesc->Epsilon = desc.Epsilon;
748+
pDMLDesc->FusedActivation = pDMLFused;
749+
750+
if (pDMLFused)
751+
{
752+
assert(desc.pFusedActivation);
753+
const auto offset = sizeof(DML_CONVOLUTION_OPERATOR_DESC) + sizeof(DML_OPERATOR_DESC);
754+
pDMLFused->Type = GetDMLOpteratorType(desc.FusedActivationType);
755+
pDMLFused->Desc = &dmlTypedOpDesc[offset];
756+
memcpy(&dmlTypedOpDesc[offset], typedFused.data(), typedFused.size());
757+
}
758+
};
759+
760+
static const auto getDMLLocalResponseNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
761+
{
762+
dmlTypedOpDesc.resize(sizeof(DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC));
763+
const auto pDMLDesc = reinterpret_cast<DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());
764+
const auto& desc = *static_cast<const LocalResponseNormalization*>(pOpDesc);
765+
766+
pDMLDesc->InputTensor = desc.pInput ? static_cast<const DML_TENSOR_DESC*>(desc.pInput->GetHandle()) : nullptr;
767+
pDMLDesc->OutputTensor = desc.pOutput ? static_cast<const DML_TENSOR_DESC*>(desc.pOutput->GetHandle()) : nullptr;
768+
pDMLDesc->CrossChannel = desc.CrossChannel;
769+
pDMLDesc->LocalSize = desc.LocalSize;
770+
pDMLDesc->Alpha = desc.Alpha;
771+
pDMLDesc->Beta = desc.Beta;
772+
pDMLDesc->Bias = desc.Bias;
773+
};
774+
775+
static const auto getDMLLPNormalization = [](vector<uint8_t>& dmlTypedOpDesc, const void* pOpDesc)
776+
{
777+
dmlTypedOpDesc.resize(sizeof(DML_LP_NORMALIZATION_OPERATOR_DESC));
778+
const auto pDMLDesc = reinterpret_cast<DML_LP_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data());
779+
const auto& desc = *static_cast<const LPNormalization*>(pOpDesc);
780+
781+
pDMLDesc->InputTensor = desc.pInput ? static_cast<const DML_TENSOR_DESC*>(desc.pInput->GetHandle()) : nullptr;
782+
pDMLDesc->OutputTensor = desc.pOutput ? static_cast<const DML_TENSOR_DESC*>(desc.pOutput->GetHandle()) : nullptr;
783+
pDMLDesc->Axis = desc.Axis;
784+
pDMLDesc->Epsilon = desc.Epsilon;
785+
pDMLDesc->P = desc.P;
786+
};
787+
788+
695789
static const function<void(vector<uint8_t>&, const void*)> pfnGetDMLOps[] =
696790
{
697791
nullptr, // INVALID
@@ -768,6 +862,11 @@ void ML::GetDMLTypedOperator(vector<uint8_t>& dmlTypedOpDesc, OperatorType type,
768862
getDMLSpaceDepth, // DEPTH_TO_SPACE
769863
getDMLTile, // TILE
770864
getDMLTopK, // TOP_K
865+
866+
getDMLBatchNormalization, // BATCH_NORMALIZATION
867+
getDMLMeanVarianceNormalization, // MEAN_VARIANCE_NORMALIZATION
868+
getDMLLocalResponseNormalization, // LOCAL_RESPONSE_NORMALIZATION
869+
getDMLLPNormalization, // LP_NORMALIZATION
771870
};
772871

773872
pfnGetDMLOps[static_cast<uint32_t>(type)](dmlTypedOpDesc, pOpDesc);

XUSGMachineLearning/MachineLearning/XUSGMachineLearning_DML.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ namespace XUSG
4747
com_ptr<IDMLDevice> m_device;
4848
};
4949

50-
using BatchNormalization = DML_BATCH_NORMALIZATION_OPERATOR_DESC;
51-
using MeanVarianceNormalization = DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC;
52-
using LocalResponseNormalization = DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC;
53-
using LPNormalization = DML_LP_NORMALIZATION_OPERATOR_DESC;
54-
5550
using RNNOperator = DML_RNN_OPERATOR_DESC;
5651
using LSTMOperator = DML_LSTM_OPERATOR_DESC;
5752
using GRUOperator = DML_GRU_OPERATOR_DESC;

0 commit comments

Comments
 (0)