
Commit a22007b

[NeoML] CTiedEmbeddingsLayer is learnable with empty paramBlobs (#1106)
Signed-off-by: Kirill Golikov <kirill.golikov@abbyy.com>
1 parent ef317f2 commit a22007b

9 files changed, +28 -26 lines changed
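
For context: a tied-embeddings layer reuses the weight matrix of a lookup (embedding) layer as its output projection, so it is learnable yet owns no parameter blobs of its own. This commit teaches CBaseLayer::transferParamsBlob to accept that combination and makes CDnnSolver::AddDiff const-correct so the tied layer can register gradients on the lookup layer's behalf. A standalone sketch of the ownership pattern (plain C++ for illustration only, not the NeoML API; all names are invented):

    #include <vector>

    // The lookup layer owns the embedding matrix.
    struct Embedding {
        std::vector<float> matrix; // vocabSize x dim: the only real storage
    };

    // The tied projection borrows that matrix instead of owning weights, so
    // its own parameter list is legitimately empty while it is still
    // trainable: any gradients it computes must be applied to source->matrix.
    struct TiedProjection {
        const Embedding* source = nullptr;
    };

    int main() {
        Embedding embedding{ std::vector<float>( 1000 * 128 ) };
        TiedProjection tied{ &embedding }; // learnable, zero own parameters
        return tied.source == &embedding ? 0 : 1;
    }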

NeoML/include/NeoML/Dnn/Dnn.h

Lines changed: 4 additions & 2 deletions
@@ -244,8 +244,10 @@ class NEOML_API CBaseLayer : public virtual IObject {
 	bool IsBackwardPerformed() const;
 	// Indicates that backpropagation must be performed for the layer when Learn method is called
 	bool IsBackwardNeeded() const;
-	// Layer may contain empty paramBlob of given index
-	virtual bool ContainsEmptyParamBlob( int ) const { return false; }
+	// Layer may contain null paramBlob of given index, specialization for transferParamsBlob
+	virtual bool ContainsNullParamBlob( int ) const { return false; }
+	// Special case, specialization for transferParamsBlob
+	virtual bool IsLearnableWithEmptyParamBlobs() const { return false; }
 	// Gets a pointer to the layer connected to the given input
 	CBaseLayer* GetInputLayer(int input) { return inputLinks[input].Layer; }
 	const CBaseLayer* GetInputLayer(int input) const { return inputLinks[input].Layer; }
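
A sketch of how derived layers are expected to use the two new hooks, modeled on the overrides in the files below (class names and the free-term index are hypothetical; required overrides such as Reshape and RunOnce are omitted):

    // A block layer whose paramBlobs[1] is an optional free term that may
    // legitimately be null (cf. the ContainsNullParamBlob overrides below).
    class CMyBlockLayer : public CBaseLayer {
    public:
        explicit CMyBlockLayer( IMathEngine& mathEngine ) :
            CBaseLayer( mathEngine, "CMyBlockLayer", /*isLearnable*/true ) {}
        bool ContainsNullParamBlob( int i ) const override
            { return paramBlobs[i] == nullptr && i == 1; }
    };

    // A layer that is learnable although the weights it trains live in another
    // layer, so its own paramBlobs array stays empty; this is the case that
    // IsLearnableWithEmptyParamBlobs() legalizes in transferParamsBlob.
    class CMyTiedLayer : public CBaseLayer {
    public:
        explicit CMyTiedLayer( IMathEngine& mathEngine ) :
            CBaseLayer( mathEngine, "CMyTiedLayer", /*isLearnable*/true ) {}
        bool IsLearnableWithEmptyParamBlobs() const override { return true; }
    };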

NeoML/include/NeoML/Dnn/DnnSolver.h

Lines changed: 3 additions & 3 deletions
@@ -29,7 +29,7 @@ class NEOML_API CDnnSolver : virtual public IObject {
 public:
 	// Stores the calculated values of layer parameters gradients for further use in Train method
 	// forSharedWeightsLayer=true should only be used within layers that share weights with other layers.
-	void AddDiff( CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramDiffBlobs,
+	void AddDiff( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramDiffBlobs,
 		bool sharedWeights = false );
 
 	// Modifies the trainable parameters of the network layers,
@@ -98,8 +98,8 @@ class NEOML_API CDnnSolver : virtual public IObject {
 	// Used in the inheriting classes
 	CMap<CString, CObjectArray<CDnnBlob>> layerToGradientHistory;
 	// Layers which require reduction across distributed solver
-	CHashTable<CBaseLayer*> layersToReduce; // Fast check if layer is included already
-	CArray<CBaseLayer*> reduceOrder; // Correct order across all of the distributed nets
+	CHashTable<const CBaseLayer*> layersToReduce; // Fast check if layer is included already
+	CArray<const CBaseLayer*> reduceOrder; // Correct order across all of the distributed nets
 
 	// Averages weights over all threads
 	void allReduce( float distributedCoeff );
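
The container changes follow mechanically from the new signature: once AddDiff takes const CBaseLayer*, layersToReduce and reduceOrder must store const pointers too, because a pointer-to-const cannot be stored in a container of pointers-to-non-const. A standalone illustration (plain C++, not the NeoML containers):

    #include <vector>

    struct Layer {};

    int main() {
        const Layer layer;
        const Layer* p = &layer;               // what AddDiff now receives
        std::vector<const Layer*> reduceOrder; // mirrors CArray<const CBaseLayer*>
        reduceOrder.push_back( p );            // compiles
        // std::vector<Layer*> old;
        // old.push_back( p ); // would not compile: discards const
        return reduceOrder.size() == 1 ? 0 : 1;
    }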

NeoML/include/NeoML/Dnn/Layers/ChannelwiseWith1x1Layer.h

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ class NEOML_API CChannelwiseWith1x1Layer : public CBaseLayer {
 	void RunOnce() override;
 	void BackwardOnce() override { NeoAssert( false ); }
 	// Specialization for transferParamsBlob
-	bool ContainsEmptyParamBlob( int i ) const override
+	bool ContainsNullParamBlob( int i ) const override
 		{ return paramBlobs[i] == nullptr && ( i == P_ChannelwiseFreeTerm || i == P_ConvFreeTerm ); }
 
 private:

NeoML/include/NeoML/Dnn/Layers/MobileNetV2BlockLayer.h

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ class NEOML_API CMobileNetV2BlockLayer : public CBaseLayer {
 	void RunOnce() override;
 	void BackwardOnce() override { NeoAssert( false ); }
 	// Specialization for transferParamsBlob
-	bool ContainsEmptyParamBlob( int i ) const override
+	bool ContainsNullParamBlob( int i ) const override
 		{ return !paramBlobs[i] && ( i == P_ChannelwiseFreeTerm || i == P_DownFreeTerm || i == P_ExpandFreeTerm ); }
 
 private:

NeoML/include/NeoML/Dnn/Layers/MobileNetV3BlockLayer.h

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ class NEOML_API CMobileNetV3PreSEBlockLayer : public CBaseLayer {
 	void RunOnce() override;
 	void BackwardOnce() override { NeoAssert( false ); }
 	// Specialization for transferParamsBlob
-	bool ContainsEmptyParamBlob( int i ) const override
+	bool ContainsNullParamBlob( int i ) const override
 		{ return paramBlobs[i] == nullptr && ( i == P_ChannelwiseFreeTerm || i == P_ExpandFreeTerm ); }
 
 private:
@@ -115,7 +115,7 @@ class NEOML_API CMobileNetV3PostSEBlockLayer : public CBaseLayer {
 	void RunOnce() override;
 	void BackwardOnce() override { NeoAssert( false ); }
 	// Specialization for transferParamsBlob
-	bool ContainsEmptyParamBlob( int i ) const override
+	bool ContainsNullParamBlob( int i ) const override
 		{ return paramBlobs[i] == nullptr && ( i == P_DownFreeTerm ); }
 
 private:

NeoML/include/NeoML/Dnn/Layers/TiedEmbeddingsLayer.h

Lines changed: 6 additions & 3 deletions
@@ -27,7 +27,8 @@ class CMultichannelLookupLayer;
 class NEOML_API CTiedEmbeddingsLayer : public CBaseLayer {
 	NEOML_DNN_LAYER( CTiedEmbeddingsLayer )
 public:
-	explicit CTiedEmbeddingsLayer( IMathEngine& mathEngine );
+	explicit CTiedEmbeddingsLayer( IMathEngine& mathEngine ) :
+		CBaseLayer( mathEngine, "CTiedEmbeddingsLayer", /*isLearnable*/true ) {}
 
 	void Serialize( CArchive& archive ) override;
 
@@ -54,17 +55,19 @@ class NEOML_API CTiedEmbeddingsLayer : public CBaseLayer {
 	void LearnOnce() override;
 	int BlobsForBackward() const override { return 0; }
 	int BlobsForLearn() const override { return TInputBlobs; }
+	// Special case, specialization for transferParamsBlob
+	bool IsLearnableWithEmptyParamBlobs() const override { return true; }
 
 private:
 	// Path for embedding layer from which matrix is taken
 	// Now it contains the path as array
 	// So in case of no composite layer it is gonna be { "embeddingName" }
 	CArray<CString> embeddingPath;
 	// Channel index in embedding layer.
-	int channelIndex;
+	int channelIndex = 0;
 
 	const CDnnBlob* getEmbeddingsTable() const;
-	CMultichannelLookupLayer* getLookUpLayer() const;
+	const CMultichannelLookupLayer* getLookUpLayer() const;
 };
 
 // Tied embeddings.
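
The path array works like a path through nested composite layers; a short fragment, assuming only the NeoML containers already used in this header (layer names are hypothetical):

    // An embedding nested in a composite layer is addressed by listing every
    // enclosing layer name; a top-level embedding is just { "embeddingName" }.
    CArray<CString> path;
    path.Add( "composite" ); // hypothetical enclosing composite layer
    path.Add( "embedding" ); // hypothetical lookup layer name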

NeoML/src/Dnn/BaseLayer.cpp

Lines changed: 5 additions & 1 deletion
@@ -314,13 +314,17 @@ void CBaseLayer::transferParamsBlob( CBaseLayer& dist ) const
 		}
 	} else {
 		NeoAssertMsg( dist.paramBlobs.Size() == paramBlobs.Size(), "transferParamsBlob: It isn't a copy of the layer" );
+		if( IsLearnableWithEmptyParamBlobs() ) { // Special case is CTiedEmbeddingsLayer
+			NeoAssert( dist.IsLearnable() && paramBlobs.Size() == 0 );
+			return;
+		}
 
 		NeoAssertMsg( !dist.IsLearnable() || paramBlobs.Size() > 0,
 			"transferParamsBlob: The origin dnn should be trained and reshaped to create a reference dnn" );
 		// Create reference copy of dist.paramBlobs with shared buffer
 		// Takes a pointer to parent's blob to access memory
 		for( int j = 0; j < dist.paramBlobs.Size(); ++j ) {
-			if( ContainsEmptyParamBlob( j ) ) {
+			if( ContainsNullParamBlob( j ) ) {
 				dist.paramBlobs[j] = nullptr; // may contain empty parameter
 				continue;
 			}
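
Taken together, this branch of transferParamsBlob now short-circuits for layers that train foreign weights before the per-blob loop runs; an abridged restatement of the hunk above (not a complete listing):

    if( IsLearnableWithEmptyParamBlobs() ) { // e.g. CTiedEmbeddingsLayer
        // Learnable, but the real weights live in the lookup layer,
        // so there is nothing to wrap for the reference dnn.
        NeoAssert( dist.IsLearnable() && paramBlobs.Size() == 0 );
        return;
    }
    for( int j = 0; j < dist.paramBlobs.Size(); ++j ) {
        if( ContainsNullParamBlob( j ) ) { // optional free term, legally null
            dist.paramBlobs[j] = nullptr;
            continue;
        }
        // otherwise wrap the parent's blob so dist shares its memory
    }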

NeoML/src/Dnn/DnnSolver.cpp

Lines changed: 1 addition & 1 deletion
@@ -140,7 +140,7 @@ CDnnSolver::CDnnSolver( IMathEngine& _mathEngine ) :
 }
 
 // Calculates the layer parameter gradients to then use them in Train method
-void CDnnSolver::AddDiff( CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramDiffBlobs, bool sharedWeights )
+void CDnnSolver::AddDiff( const CBaseLayer* layer, const CObjectArray<CDnnBlob>& paramDiffBlobs, bool sharedWeights )
 {
 	NeoAssert( layer != nullptr );
 	if( MathEngine().IsDistributed() && !layersToReduce.Has( layer ) ) {

NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp

Lines changed: 5 additions & 12 deletions
@@ -21,12 +21,6 @@ limitations under the License.
 
 namespace NeoML {
 
-CTiedEmbeddingsLayer::CTiedEmbeddingsLayer( IMathEngine& mathEngine ) :
-	CBaseLayer( mathEngine, "CTiedEmbeddingsLayer", true ),
-	channelIndex( 0 )
-{
-}
-
 void CTiedEmbeddingsLayer::SetChannelIndex( int val )
 {
 	NeoAssert( val >= 0 );
@@ -131,7 +125,7 @@ void CTiedEmbeddingsLayer::LearnOnce()
 		diffBlob->Clear();
 	}
 
-	CMultichannelLookupLayer* embeddingsLayer = getLookUpLayer();
+	const CMultichannelLookupLayer* embeddingsLayer = getLookUpLayer();
 	CObjectArray<CDnnBlob> totalDiffBlobs;
 	const int channelsCount = embeddingsLayer->GetDimensions().Size();
 	for( int i = 0; i < channelsCount; i++ ) {
@@ -144,7 +138,7 @@ void CTiedEmbeddingsLayer::LearnOnce()
 		}
 	}
 
-	GetDnn()->GetSolver()->AddDiff( embeddingsLayer, totalDiffBlobs, true );
+	GetDnn()->GetSolver()->AddDiff( embeddingsLayer, totalDiffBlobs, /*sharedWeights*/true );
 }
 
 // Embeddings matrix
@@ -155,11 +149,10 @@ const CDnnBlob* CTiedEmbeddingsLayer::getEmbeddingsTable() const
 	return getLookUpLayer()->GetEmbeddings( channelIndex );
 }
 
-CMultichannelLookupLayer* CTiedEmbeddingsLayer::getLookUpLayer() const
+const CMultichannelLookupLayer* CTiedEmbeddingsLayer::getLookUpLayer() const
 {
-	CMultichannelLookupLayer* embeddingsLayer;
-	embeddingsLayer = CheckCast<CMultichannelLookupLayer>(
-		const_cast<CDnn*>(GetDnn())->GetLayer(embeddingPath).Ptr());
+	const CMultichannelLookupLayer* embeddingsLayer
+		= CheckCast<CMultichannelLookupLayer>( GetDnn()->GetLayer( embeddingPath ).Ptr() );
 	return embeddingsLayer;
 }
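
These LearnOnce and getLookUpLayer changes are what motivated the solver's const-correctness work: the tied layer accumulates diffs for the lookup layer's channels and hands them to the solver on that layer's behalf, now through a const pointer (no const_cast needed) and with the flag spelled out, since per the AddDiff comment sharedWeights=true is reserved for layers that share weights with other layers. Abridged from LearnOnce above:

    const CMultichannelLookupLayer* embeddingsLayer = getLookUpLayer();
    CObjectArray<CDnnBlob> totalDiffBlobs;
    // ...one diff blob per lookup channel is collected here (see the hunk above)...
    GetDnn()->GetSolver()->AddDiff( embeddingsLayer, totalDiffBlobs, /*sharedWeights*/true );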
