@@ -692,6 +692,100 @@ void ML::GetDMLTypedOperator(vector<uint8_t>& dmlTypedOpDesc, OperatorType type,
692
692
pDMLDesc->K = desc.K ;
693
693
};
694
694
695
+ static const auto getDMLBatchNormalization = [](vector<uint8_t >& dmlTypedOpDesc, const void * pOpDesc)
696
+ {
697
+ const auto & desc = *static_cast <const BatchNormalization*>(pOpDesc);
698
+
699
+ vector<uint8_t > typedFused (0 );
700
+ if (desc.pFusedActivation ) GetDMLTypedOperator (typedFused, desc.FusedActivationType , desc.pFusedActivation );
701
+
702
+ dmlTypedOpDesc.resize (sizeof (DML_BATCH_NORMALIZATION_OPERATOR_DESC) +
703
+ (desc.pFusedActivation ? sizeof (DML_OPERATOR_DESC) + typedFused.size () : 0 ));
704
+ const auto pDMLDesc = reinterpret_cast <DML_BATCH_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data ());
705
+ const auto pDMLFused = desc.pFusedActivation ? reinterpret_cast <DML_OPERATOR_DESC*>(
706
+ &dmlTypedOpDesc[sizeof (DML_BATCH_NORMALIZATION_OPERATOR_DESC)]) : nullptr ;
707
+
708
+ pDMLDesc->InputTensor = desc.pInput ? static_cast <const DML_TENSOR_DESC*>(desc.pInput ->GetHandle ()) : nullptr ;
709
+ pDMLDesc->MeanTensor = desc.pMean ? static_cast <const DML_TENSOR_DESC*>(desc.pMean ->GetHandle ()) : nullptr ;
710
+ pDMLDesc->VarianceTensor = desc.pVariance ? static_cast <const DML_TENSOR_DESC*>(desc.pVariance ->GetHandle ()) : nullptr ;
711
+ pDMLDesc->ScaleTensor = desc.pScale ? static_cast <const DML_TENSOR_DESC*>(desc.pScale ->GetHandle ()) : nullptr ;
712
+ pDMLDesc->BiasTensor = desc.pBias ? static_cast <const DML_TENSOR_DESC*>(desc.pBias ->GetHandle ()) : nullptr ;
713
+ pDMLDesc->OutputTensor = desc.pOutput ? static_cast <const DML_TENSOR_DESC*>(desc.pOutput ->GetHandle ()) : nullptr ;
714
+ pDMLDesc->Spatial = desc.Spatial ;
715
+ pDMLDesc->Epsilon = desc.Epsilon ;
716
+ pDMLDesc->FusedActivation = pDMLFused;
717
+
718
+ if (pDMLFused)
719
+ {
720
+ assert (desc.pFusedActivation );
721
+ const auto offset = sizeof (DML_CONVOLUTION_OPERATOR_DESC) + sizeof (DML_OPERATOR_DESC);
722
+ pDMLFused->Type = GetDMLOpteratorType (desc.FusedActivationType );
723
+ pDMLFused->Desc = &dmlTypedOpDesc[offset];
724
+ memcpy (&dmlTypedOpDesc[offset], typedFused.data (), typedFused.size ());
725
+ }
726
+ };
727
+
728
+ static const auto getDMLMeanVarianceNormalization = [](vector<uint8_t >& dmlTypedOpDesc, const void * pOpDesc)
729
+ {
730
+ const auto & desc = *static_cast <const MeanVarianceNormalization*>(pOpDesc);
731
+
732
+ vector<uint8_t > typedFused (0 );
733
+ if (desc.pFusedActivation ) GetDMLTypedOperator (typedFused, desc.FusedActivationType , desc.pFusedActivation );
734
+
735
+ dmlTypedOpDesc.resize (sizeof (DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC) +
736
+ (desc.pFusedActivation ? sizeof (DML_OPERATOR_DESC) + typedFused.size () : 0 ));
737
+ const auto pDMLDesc = reinterpret_cast <DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data ());
738
+ const auto pDMLFused = desc.pFusedActivation ? reinterpret_cast <DML_OPERATOR_DESC*>(
739
+ &dmlTypedOpDesc[sizeof (DML_MEAN_VARIANCE_NORMALIZATION_OPERATOR_DESC)]) : nullptr ;
740
+
741
+ pDMLDesc->InputTensor = desc.pInput ? static_cast <const DML_TENSOR_DESC*>(desc.pInput ->GetHandle ()) : nullptr ;
742
+ pDMLDesc->ScaleTensor = desc.pScale ? static_cast <const DML_TENSOR_DESC*>(desc.pScale ->GetHandle ()) : nullptr ;
743
+ pDMLDesc->BiasTensor = desc.pBias ? static_cast <const DML_TENSOR_DESC*>(desc.pBias ->GetHandle ()) : nullptr ;
744
+ pDMLDesc->OutputTensor = desc.pOutput ? static_cast <const DML_TENSOR_DESC*>(desc.pOutput ->GetHandle ()) : nullptr ;
745
+ pDMLDesc->CrossChannel = desc.CrossChannel ;
746
+ pDMLDesc->NormalizeVariance = desc.NormalizeVariance ;
747
+ pDMLDesc->Epsilon = desc.Epsilon ;
748
+ pDMLDesc->FusedActivation = pDMLFused;
749
+
750
+ if (pDMLFused)
751
+ {
752
+ assert (desc.pFusedActivation );
753
+ const auto offset = sizeof (DML_CONVOLUTION_OPERATOR_DESC) + sizeof (DML_OPERATOR_DESC);
754
+ pDMLFused->Type = GetDMLOpteratorType (desc.FusedActivationType );
755
+ pDMLFused->Desc = &dmlTypedOpDesc[offset];
756
+ memcpy (&dmlTypedOpDesc[offset], typedFused.data (), typedFused.size ());
757
+ }
758
+ };
759
+
760
+ static const auto getDMLLocalResponseNormalization = [](vector<uint8_t >& dmlTypedOpDesc, const void * pOpDesc)
761
+ {
762
+ dmlTypedOpDesc.resize (sizeof (DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC));
763
+ const auto pDMLDesc = reinterpret_cast <DML_LOCAL_RESPONSE_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data ());
764
+ const auto & desc = *static_cast <const LocalResponseNormalization*>(pOpDesc);
765
+
766
+ pDMLDesc->InputTensor = desc.pInput ? static_cast <const DML_TENSOR_DESC*>(desc.pInput ->GetHandle ()) : nullptr ;
767
+ pDMLDesc->OutputTensor = desc.pOutput ? static_cast <const DML_TENSOR_DESC*>(desc.pOutput ->GetHandle ()) : nullptr ;
768
+ pDMLDesc->CrossChannel = desc.CrossChannel ;
769
+ pDMLDesc->LocalSize = desc.LocalSize ;
770
+ pDMLDesc->Alpha = desc.Alpha ;
771
+ pDMLDesc->Beta = desc.Beta ;
772
+ pDMLDesc->Bias = desc.Bias ;
773
+ };
774
+
775
+ static const auto getDMLLPNormalization = [](vector<uint8_t >& dmlTypedOpDesc, const void * pOpDesc)
776
+ {
777
+ dmlTypedOpDesc.resize (sizeof (DML_LP_NORMALIZATION_OPERATOR_DESC));
778
+ const auto pDMLDesc = reinterpret_cast <DML_LP_NORMALIZATION_OPERATOR_DESC*>(dmlTypedOpDesc.data ());
779
+ const auto & desc = *static_cast <const LPNormalization*>(pOpDesc);
780
+
781
+ pDMLDesc->InputTensor = desc.pInput ? static_cast <const DML_TENSOR_DESC*>(desc.pInput ->GetHandle ()) : nullptr ;
782
+ pDMLDesc->OutputTensor = desc.pOutput ? static_cast <const DML_TENSOR_DESC*>(desc.pOutput ->GetHandle ()) : nullptr ;
783
+ pDMLDesc->Axis = desc.Axis ;
784
+ pDMLDesc->Epsilon = desc.Epsilon ;
785
+ pDMLDesc->P = desc.P ;
786
+ };
787
+
788
+
695
789
static const function<void (vector<uint8_t >&, const void *)> pfnGetDMLOps[] =
696
790
{
697
791
nullptr , // INVALID
@@ -768,6 +862,11 @@ void ML::GetDMLTypedOperator(vector<uint8_t>& dmlTypedOpDesc, OperatorType type,
768
862
getDMLSpaceDepth, // DEPTH_TO_SPACE
769
863
getDMLTile, // TILE
770
864
getDMLTopK, // TOP_K
865
+
866
+ getDMLBatchNormalization, // BATCH_NORMALIZATION
867
+ getDMLMeanVarianceNormalization, // MEAN_VARIANCE_NORMALIZATION
868
+ getDMLLocalResponseNormalization, // LOCAL_RESPONSE_NORMALIZATION
869
+ getDMLLPNormalization, // LP_NORMALIZATION
771
870
};
772
871
773
872
pfnGetDMLOps[static_cast <uint32_t >(type)](dmlTypedOpDesc, pOpDesc);
0 commit comments