Skip to content

Commit

Permalink
[NeoML] CTiedEmbeddingsLayer extend functional interface (#1107)
Browse files Browse the repository at this point in the history
Signed-off-by: Kirill Golikov <kirill.golikov@abbyy.com>
  • Loading branch information
favorart authored Sep 9, 2024
1 parent a22007b commit 1c369e3
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 29 deletions.
57 changes: 34 additions & 23 deletions NeoML/include/NeoML/Dnn/DnnLambdaHolder.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright © 2017-2020 ABBYY Production LLC
/* Copyright © 2017-2024 ABBYY
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -15,58 +15,70 @@ limitations under the License.

#pragma once

#include <type_traits>
#include <NeoML/NeoMLDefs.h>

namespace NeoML {

//////////////////////////////////////////////////////////////////////////////////////////

// Simple analog for std::function() that does not use std::allocator

namespace details {

template<typename T>
class CLambdaHolderBase {};
class ILambdaHolderBase;

// Base class for lambda holder. This interface hide actual lambda type.
template<typename Out, typename ...In>
class CLambdaHolderBase<Out( In... )> : public virtual IObject {
class ILambdaHolderBase<Out( In... )> : public virtual IObject {
public:
// Executes lambda.
virtual Out Execute( In... arguments ) = 0;
// Copies lambda.
virtual CPtr<CLambdaHolderBase<Out( In... )>> Copy() = 0;
};

//---------------------------------------------------------------------------------------------------------------------

template<typename T, typename U>
class CLambdaHolder {};
class CLambdaHolder;

// Lambda holder implementation.
template<typename T, typename Out, typename ...In>
class CLambdaHolder<T, Out( In... )> : public CLambdaHolderBase<Out( In... )> {
template<typename F, typename Out, typename ...In>
class CLambdaHolder<F, Out( In... )> : public ILambdaHolderBase<Out( In... )> {
public:
CLambdaHolder( T _lambda ) : lambda( _lambda ) {}
CLambdaHolder( F&& func ) : lambda( std::move( func ) ) {}
CLambdaHolder( const F& func ) : lambda( func ) {}

Out Execute( In... in ) override
{ return lambda( in... ); }

CPtr<CLambdaHolderBase<Out( In... )>> Copy() override
{ return new CLambdaHolder<T, Out( In... )>( lambda ); }

private:
T lambda;
F lambda;
};

} // namespace details

//---------------------------------------------------------------------------------------------------------------------

template<typename T>
class CLambda {};
class CLambda;

// Type that captures lambda.
template<typename Out, typename ...In>
class CLambda<Out( In... )> {
public:
CLambda() {}
template<class T>
CLambda( const T& t ) : lambda( new CLambdaHolder<T, Out( In... )>( t ) ) {}
CLambda( const CLambda& other ) :
lambda( other.lambda != 0 ? other.lambda->Copy() : nullptr ) {}
CLambda() = default;
// Be copied and moved by default, because it stores the shared pointer

// Convert from a function, except itself type
// By coping
template<class F,
typename std::enable_if<!std::is_same<CLambda, typename std::decay<F>::type>::value, int>::type = 0>
CLambda( const F& function ) :
lambda( new details::CLambdaHolder<F, Out( In... )>( function ) ) {}
// By moving
template<class F,
typename std::enable_if<!std::is_same<CLambda, typename std::decay<F>::type>::value, int>::type = 0>
CLambda( F&& function ) :
lambda( new details::CLambdaHolder<F, Out( In... )>( std::move( function ) ) ) {}

Out operator()( In... in )
{
Expand All @@ -78,8 +90,7 @@ class CLambda<Out( In... )> {
bool IsEmpty() const { return lambda == nullptr; }

private:
CPtr<CLambdaHolderBase<Out( In... )>> lambda;
CPtr<details::ILambdaHolderBase<Out( In... )>> lambda;
};

//////////////////////////////////////////////////////////////////////////////////////////
} // namespace NeoML
3 changes: 2 additions & 1 deletion NeoML/include/NeoML/Dnn/Layers/TiedEmbeddingsLayer.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class NEOML_API CTiedEmbeddingsLayer : public CBaseLayer {
};

// Tied embeddings.
NEOML_API CLayerWrapper<CTiedEmbeddingsLayer> TiedEmbeddings( const char* name, int channel );
NEOML_API CLayerWrapper<CTiedEmbeddingsLayer> TiedEmbeddings( const char* name, int channel,
CArray<CString>&& embeddingPath = {} );

} // namespace NeoML
14 changes: 9 additions & 5 deletions NeoML/src/Dnn/Layers/TiedEmbeddingsLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,16 @@ const CMultichannelLookupLayer* CTiedEmbeddingsLayer::getLookUpLayer() const
return embeddingsLayer;
}

CLayerWrapper<CTiedEmbeddingsLayer> TiedEmbeddings( const char* name, int channel )
CLayerWrapper<CTiedEmbeddingsLayer> TiedEmbeddings( const char* name, int channel, CArray<CString>&& embeddingPath )
{
return CLayerWrapper<CTiedEmbeddingsLayer>( "TiedEmbeddings", [=]( CTiedEmbeddingsLayer* result ) {
result->SetEmbeddingsLayerName( name );
result->SetChannelIndex( channel );
} );
return CLayerWrapper<CTiedEmbeddingsLayer>( "TiedEmbeddings",
[=, path=std::move( embeddingPath )]( CTiedEmbeddingsLayer* result )
{
result->SetEmbeddingsLayerName( name );
result->SetChannelIndex( channel );
result->SetEmbeddingsLayerPath( path );
}
);
}

} // namespace NeoML

0 comments on commit 1c369e3

Please sign in to comment.