Skip to content

Commit

Permalink
New OrthogonalProduct*Factory.getMarginal()
Browse files Browse the repository at this point in the history
New OrthogonalProductPolynomialFactory.getMarginal()
New OrthogonalProductFunctionFactory.getMarginal()
New OrthogonalProductFunctionFactory.build(indices)
New OrthogonalProductPolynomialFactory.build(indices)

Closes openturns#2648

Add a new unit test to getFunctionFamilyCollection()
  • Loading branch information
mbaudin47 committed May 14, 2024
1 parent 82f2a47 commit aedf2c5
Show file tree
Hide file tree
Showing 15 changed files with 484 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

BEGIN_NAMESPACE_OPENTURNS



TEMPLATE_CLASSNAMEINIT(PersistentCollection<OrthogonalUniVariateFunctionFamily>)

static const Factory<PersistentCollection<OrthogonalUniVariateFunctionFamily> > Factory_PersistentCollection_OrthogonalUniVariateFunctionFamily;
Expand All @@ -53,7 +51,6 @@ OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory()
// Nothing to do
}


/* Constructor */
OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const FunctionFamilyCollection & coll)
: OrthogonalFunctionFactory()
Expand All @@ -62,7 +59,6 @@ OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const Functio
buildMeasure(coll);
}


/* Constructor */
OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const FunctionFamilyCollection & coll,
const EnumerateFunction & phi)
Expand All @@ -73,14 +69,12 @@ OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const Functio
buildMeasure(coll);
}


/* Virtual constructor */
OrthogonalProductFunctionFactory * OrthogonalProductFunctionFactory::clone() const
{
return new OrthogonalProductFunctionFactory(*this);
}


/* Return the enumerate function that translate unidimensional indices into multidimensional indices */
EnumerateFunction OrthogonalProductFunctionFactory::getEnumerateFunction() const
{
Expand All @@ -99,13 +93,17 @@ OrthogonalProductFunctionFactory::FunctionFamilyCollection OrthogonalProductFunc
return coll;
}


/* Build the Function of the given index */
Function OrthogonalProductFunctionFactory::build(const UnsignedInteger index) const
{
return tensorizedFunctionFactory_.build(index);
}

/* Build the Function of the given index */
Function OrthogonalProductFunctionFactory::build(const Indices & indices) const
{
return tensorizedFunctionFactory_.build(getEnumerateFunction().inverse(indices));
}

/* String converter */
String OrthogonalProductFunctionFactory::__repr__() const
Expand All @@ -115,15 +113,13 @@ String OrthogonalProductFunctionFactory::__repr__() const
<< " measure=" << measure_;
}


/* Method save() stores the object through the StorageManager */
void OrthogonalProductFunctionFactory::save(Advocate & adv) const
{
OrthogonalFunctionFactory::save(adv);
adv.saveAttribute("tensorizedFunctionFactory_", tensorizedFunctionFactory_);
}


/* Method load() reloads the object from the StorageManager */
void OrthogonalProductFunctionFactory::load(Advocate & adv)
{
Expand Down Expand Up @@ -157,5 +153,17 @@ void OrthogonalProductFunctionFactory::buildMeasure(const FunctionFamilyCollecti
measure_ = JointDistribution(distributions);
}

/* Get marginal functions */
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const
{
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionColl(tensorizedFunctionFactory_.getFunctionFamilyCollection());
const UnsignedInteger size = functionColl.getSize();
indices.check(size);
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionMarginalCollection;
for (UnsignedInteger index = 0; index < size; ++ index)
if (indices.contains(index))
functionMarginalCollection.add(functionColl[index]);
return functionMarginalCollection;
}

END_NAMESPACE_OPENTURNS
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,16 @@ Sample OrthogonalProductPolynomialFactory::getNodesAndWeights(const Indices & de
return nodes;
}

/* Get marginal functions */
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const
{
const UnsignedInteger size = coll_.getSize();
indices.check(size);
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection polynomialMarginalCollection;
for (UnsignedInteger index = 0; index < size; ++ index)
if (indices.contains(index))
polynomialMarginalCollection.add(coll_[index]);
return polynomialMarginalCollection;
}

END_NAMESPACE_OPENTURNS
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,18 @@ public:
/** Build the Function of the given index */
Function build(const UnsignedInteger index) const override;

/** Build the Function of the given multi-indices */
Function build(const Indices & indices) const;

/** Return the enumerate function that translate unidimensional indices nto multidimensional indices */
EnumerateFunction getEnumerateFunction() const override;

/** Return the collection of univariate orthogonal polynomial families */
FunctionFamilyCollection getFunctionFamilyCollection() const;

/** Get marginal functions */
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection getMarginal(const Indices & indices) const;

/** Virtual constructor */
OrthogonalProductFunctionFactory * clone() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#define OPENTURNS_ORTHOGONALPRODUCTPOLYNOMIALFACTORY_HXX

#include "openturns/OrthogonalFunctionFactory.hxx"
#include "openturns/OrthogonalProductFunctionFactory.hxx"
#include "openturns/Distribution.hxx"
#include "openturns/Indices.hxx"
#include "openturns/Point.hxx"
Expand Down Expand Up @@ -80,6 +81,9 @@ public:
Sample getNodesAndWeights(const Indices & degrees,
Point & weightsOut) const;

/** Get marginal functions */
PolynomialFamilyCollection getMarginal(const Indices & indices) const;

/** String converter */
String __repr__() const override;
String __str__(const String & offset = "") const override;
Expand Down
1 change: 1 addition & 0 deletions lib/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ ot_check_test (LinearModelAlgorithm_std)
ot_check_test (LinearModelAnalysis_std)
ot_check_test (KrigingAlgorithm_isotropic_std IGNOREOUT)
ot_check_test (OrthogonalProductPolynomialFactory_std)
ot_check_test (OrthogonalProductFunctionFactory_std)

if (HMAT_FOUND)
ot_check_test (KrigingAlgorithm_std_hmat)
Expand Down
121 changes: 121 additions & 0 deletions lib/test/t_OrthogonalProductFunctionFactory_std.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// -*- C++ -*-
/**
* @brief The test file of OrthogonalProductFunctionFactory class
*
* Copyright 2005-2024 Airbus-EDF-IMACS-ONERA-Phimeca
*
* This library is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "openturns/OT.hxx"
#include "openturns/OTtestcode.hxx"

using namespace OT;
using namespace OT::Test;

// Compute reference function value from index and point
Point computeFunctionValue(const UnsignedInteger & index, const Point & point)
{
const UnsignedInteger dimension = 3;
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
functionCollection[1] = HaarWaveletFactory();
functionCollection[2] = HaarWaveletFactory();
const LinearEnumerateFunction enumerate(dimension);
const TensorizedUniVariateFunctionFactory factory(functionCollection, enumerate);
const Function referenceFunction(factory.build(index));
const Point value(referenceFunction(point));
return value;
}

// Compute reference function value from multi-index and point
Point computeFunctionValue(const Indices & indices, const Point & point)
{
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
const Point value(computeFunctionValue(index, point));
return value;
}

int main(int, char *[])
{
TESTPREAMBLE;
OStream fullprint(std::cout);
setRandomGenerator();


try
{
// Create the orthogonal basis
fullprint << "Create the orthogonal basis" << std::endl;
UnsignedInteger dimension = 3;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
functionCollection[1] = HaarWaveletFactory();
functionCollection[2] = HaarWaveletFactory();

// Create linear enumerate function
fullprint << "Create linear enumerate function" << std::endl;
LinearEnumerateFunction enumerateFunction(dimension);
OrthogonalProductFunctionFactory productBasis(functionCollection, enumerateFunction);
fullprint << productBasis.__str__() << std::endl;
fullprint << productBasis.__repr_markdown__() << std::endl;
// Test the build() method on a collection of functions
const Point center({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function function(productBasis.build(i));
assert_almost_equal(function(center), computeFunctionValue(i, center));
// Test build from multi-index
const Indices indices(enumerateFunction(i));
const Function function2(productBasis.build(indices));
assert_almost_equal(function2(center), computeFunctionValue(indices, center));
}

// Heterogeneous collection
fullprint << "Heterogeneous collection" << std::endl;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection2(dimension);
functionCollection2[0] = HaarWaveletFactory();
functionCollection2[1] = FourierSeriesFactory();
functionCollection2[2] = HaarWaveletFactory();
OrthogonalProductFunctionFactory productBasis2(functionCollection2);
fullprint << productBasis2.__str__() << std::endl;
fullprint << productBasis2.__repr_markdown__() << std::endl;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection4(productBasis2.getFunctionFamilyCollection());
assert_equal((int) functionCollection4.getSize(), 3);

// Test getMarginal
fullprint << "Test getMarginal" << std::endl;
UnsignedInteger dimension2 = 5;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection3(dimension2);
functionCollection3[0] = HaarWaveletFactory();
functionCollection3[1] = FourierSeriesFactory();
functionCollection3[2] = HaarWaveletFactory();
functionCollection3[3] = HaarWaveletFactory();
functionCollection3[4] = FourierSeriesFactory();
OrthogonalProductFunctionFactory productBasis5(functionCollection3);
Indices indices({0, 2, 4});
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection productBasis6(productBasis5.getMarginal(indices));
assert_equal(productBasis6.getSize(), indices.getSize());
}
catch (TestFailed & ex)
{
std::cerr << ex << std::endl;
return ExitCode::Error;
}

return ExitCode::Success;
}
8 changes: 8 additions & 0 deletions lib/test/t_OrthogonalProductFunctionFactory_std.expout
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Create the orthogonal basis
Create linear enumerate function
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
Heterogeneous collection
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
Test getMarginal
51 changes: 49 additions & 2 deletions lib/test/t_OrthogonalProductPolynomialFactory_std.cxx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// -*- C++ -*-
/**
* @brief The test file of FunctionalChaosAlgoritm class
* @brief The test file of OrthogonalProductPolynomialFactory class
*
* Copyright 2005-2024 Airbus-EDF-IMACS-ONERA-Phimeca
*
Expand All @@ -24,6 +24,34 @@
using namespace OT;
using namespace OT::Test;

// Compute reference function value from index and point
Point computePolynomialValue(const UnsignedInteger & index, const Point & point)
{
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
// Compute the multi-indices using the EnumerateFunction
Indices indices(enumerate(index));
// Then build the collection of polynomials using the collection of factories
ProductPolynomialEvaluation::PolynomialCollection polynomials(dimension);
for (UnsignedInteger i = 0; i < dimension; ++i)
{
polynomials[i] = LegendreFactory().build(indices[i]);
}
const ProductPolynomialEvaluation product(polynomials);
const Point value(product(point));
return value;
}

// Compute reference function value from multi-index and point
Point computePolynomialValue(const Indices & indices, const Point & point)
{
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
const Point value(computePolynomialValue(index, point));
return value;
}

int main(int, char *[])
{
TESTPREAMBLE;
Expand All @@ -43,6 +71,18 @@ int main(int, char *[])
OrthogonalProductPolynomialFactory productBasis(polynomialCollection, enumerateFunction);
fullprint << productBasis.__str__() << std::endl;
fullprint << productBasis.__repr_markdown__() << std::endl;
// Test the build() method on a collection of functions
const Point center({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function polynomial(productBasis.build(i));
assert_almost_equal(polynomial(center), computePolynomialValue(i, center));
// Test build from multi-index
const Indices indices(enumerateFunction(i));
const Function polynomial2(productBasis.build(indices));
assert_almost_equal(polynomial2(center), computePolynomialValue(indices, center));
}

// Heterogeneous collection
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection polynomCollection2(dimension);
Expand Down Expand Up @@ -70,13 +110,20 @@ int main(int, char *[])
OrthogonalProductPolynomialFactory productBasis4(aCollection4);
fullprint << productBasis4.__str__() << std::endl;
fullprint << productBasis4.__repr_markdown__() << std::endl;

// Test getMarginal
UnsignedInteger dimension2 = 5;
Collection<Distribution> marginals4(dimension2, Uniform(0.0, 1.0));
OrthogonalProductPolynomialFactory productBasis5(marginals4);
Indices indices({0, 2, 4});
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection productBasis6(productBasis5.getMarginal(indices));
assert_equal(productBasis6.getSize(), indices.getSize());
}
catch (TestFailed & ex)
{
std::cerr << ex << std::endl;
return ExitCode::Error;
}


return ExitCode::Success;
}
15 changes: 15 additions & 0 deletions python/src/OrthogonalProductFunctionFactory_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,18 @@ Returns
-------
polynomialFamily : list of :class:`~openturns.OrthogonalUniVariateFunctionFamily`
List of orthogonal univariate function families."

// ---------------------------------------------------------------------

%feature("docstring") OT::OrthogonalProductFunctionFactory::getMarginal
"Get the marginal orthogonal functions.

Parameters
----------
indices : sequence of int, :math:`0 \leq i < n`
List of marginal indices.

Returns
-------
functionFamilylist : list of :class:`~openturns.OrthogonalUniVariateFunctionFamily`
The marginal orthogonal functions."
Loading

0 comments on commit aedf2c5

Please sign in to comment.