Skip to content

Commit

Permalink
New OrthogonalProduct*Factory.getMarginal()
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaudin47 committed Sep 4, 2024
1 parent 6fcf8ba commit d148b94
Show file tree
Hide file tree
Showing 19 changed files with 662 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "openturns/OSS.hxx"
#include "openturns/PersistentObjectFactory.hxx"
#include "openturns/Exception.hxx"
#include "openturns/OrthogonalBasis.hxx"

BEGIN_NAMESPACE_OPENTURNS

Expand Down Expand Up @@ -88,13 +89,17 @@ EnumerateFunction OrthogonalFunctionFactory::getEnumerateFunction() const
throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getEnumerateFunction() const";
}

/* Get the function factory corresponding to marginal input indices */
OrthogonalBasis OrthogonalFunctionFactory::getMarginal(const Indices & ) const
{
throw NotYetImplementedException(HERE) << "In OrthogonalBasis::getMarginal() const";
}

Bool OrthogonalFunctionFactory::isOrthogonal() const
{
return true;
}


/* String converter */
String OrthogonalFunctionFactory::__repr__() const
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

BEGIN_NAMESPACE_OPENTURNS



TEMPLATE_CLASSNAMEINIT(PersistentCollection<OrthogonalUniVariateFunctionFamily>)

static const Factory<PersistentCollection<OrthogonalUniVariateFunctionFamily> > Factory_PersistentCollection_OrthogonalUniVariateFunctionFamily;
Expand All @@ -53,7 +51,6 @@ OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory()
// Nothing to do
}


/* Constructor */
OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const FunctionFamilyCollection & coll)
: OrthogonalFunctionFactory()
Expand All @@ -62,7 +59,6 @@ OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const Functio
buildMeasure(coll);
}


/* Constructor */
OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const FunctionFamilyCollection & coll,
const EnumerateFunction & phi)
Expand All @@ -73,14 +69,12 @@ OrthogonalProductFunctionFactory::OrthogonalProductFunctionFactory(const Functio
buildMeasure(coll);
}


/* Virtual constructor */
OrthogonalProductFunctionFactory * OrthogonalProductFunctionFactory::clone() const
{
return new OrthogonalProductFunctionFactory(*this);
}


/* Return the enumerate function that translate unidimensional indices into multidimensional indices */
EnumerateFunction OrthogonalProductFunctionFactory::getEnumerateFunction() const
{
Expand All @@ -99,13 +93,17 @@ OrthogonalProductFunctionFactory::FunctionFamilyCollection OrthogonalProductFunc
return coll;
}


/* Build the Function of the given index */
Function OrthogonalProductFunctionFactory::build(const UnsignedInteger index) const
{
return tensorizedFunctionFactory_.build(index);
}

/* Build the Function of the given index */
Function OrthogonalProductFunctionFactory::build(const Indices & indices) const
{
return tensorizedFunctionFactory_.build(getEnumerateFunction().inverse(indices));
}

/* String converter */
String OrthogonalProductFunctionFactory::__repr__() const
Expand All @@ -115,15 +113,13 @@ String OrthogonalProductFunctionFactory::__repr__() const
<< " measure=" << measure_;
}


/* Method save() stores the object through the StorageManager */
void OrthogonalProductFunctionFactory::save(Advocate & adv) const
{
OrthogonalFunctionFactory::save(adv);
adv.saveAttribute("tensorizedFunctionFactory_", tensorizedFunctionFactory_);
}


/* Method load() reloads the object from the StorageManager */
void OrthogonalProductFunctionFactory::load(Advocate & adv)
{
Expand Down Expand Up @@ -157,5 +153,23 @@ void OrthogonalProductFunctionFactory::buildMeasure(const FunctionFamilyCollecti
measure_ = JointDistribution(distributions);
}

/* Get the function factory corresponding to marginal input indices */
OrthogonalBasis OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const
{
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionColl(getFunctionFamilyCollection());
const UnsignedInteger size = functionColl.getSize();
if (!indices.check(size))
throw InvalidArgumentException(HERE) << "The indices of a marginal sample must be in the range [0, size-1] and must be different";
// Create list of factories corresponding to input marginal indices
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionMarginalCollection;
for (UnsignedInteger index = 0; index < size; ++ index)
if (indices.contains(index))
functionMarginalCollection.add(functionColl[index]);
// Create function
const EnumerateFunction enumerateFunction(tensorizedFunctionFactory_.getEnumerateFunction());
const EnumerateFunction marginalEnumerateFunction(enumerateFunction.getMarginal(indices));
const OrthogonalProductFunctionFactory marginalFactory(functionMarginalCollection, marginalEnumerateFunction);
return marginalFactory;
}

END_NAMESPACE_OPENTURNS
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,21 @@ Sample OrthogonalProductPolynomialFactory::getNodesAndWeights(const Indices & de
return nodes;
}

/* Get the function factory corresponding to the input marginal indices */
OrthogonalBasis OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const
{
const UnsignedInteger size = coll_.getSize();
if (!indices.check(size))
throw InvalidArgumentException(HERE) << "The indices of a marginal sample must be in the range [0, size-1] and must be different";
// Create list of factories corresponding to input marginal indices
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection polynomialMarginalCollection;
for (UnsignedInteger index = 0; index < size; ++ index)
if (indices.contains(index))
polynomialMarginalCollection.add(coll_[index]);
// Create function
const EnumerateFunction marginalEnumerateFunction(phi_.getMarginal(indices));
const OrthogonalProductPolynomialFactory marginalFactory(polynomialMarginalCollection, marginalEnumerateFunction);
return marginalFactory;
}

END_NAMESPACE_OPENTURNS
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

BEGIN_NAMESPACE_OPENTURNS

class OrthogonalBasis;

/**
* @class OrthogonalFunctionFactory
*
Expand All @@ -51,7 +53,7 @@ public:
Function build(const UnsignedInteger index) const override;

/** Build the Function of the given multi-indices */
Function build(const Indices & indices) const;
virtual Function build(const Indices & indices) const;

/** Return the measure upon which the basis is orthogonal */
virtual Distribution getMeasure() const;
Expand All @@ -62,6 +64,10 @@ public:
/** Virtual constructor */
OrthogonalFunctionFactory * clone() const override;


/** Get the function factory corresponding to marginal input indices */
virtual OrthogonalBasis getMarginal(const Indices & indices) const;

Bool isOrthogonal() const override;

/** String converter */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "openturns/OrthogonalUniVariateFunctionFamily.hxx"
#include "openturns/TensorizedUniVariateFunctionFactory.hxx"
#include "openturns/EnumerateFunction.hxx"
#include "openturns/OrthogonalBasis.hxx"


BEGIN_NAMESPACE_OPENTURNS

Expand Down Expand Up @@ -62,12 +64,18 @@ public:
/** Build the Function of the given index */
Function build(const UnsignedInteger index) const override;

/** Build the Function of the given multi-indices */
Function build(const Indices & indices) const override;

/** Return the enumerate function that translate unidimensional indices nto multidimensional indices */
EnumerateFunction getEnumerateFunction() const override;

/** Return the collection of univariate orthogonal polynomial families */
FunctionFamilyCollection getFunctionFamilyCollection() const;

/** Get the function factory corresponding to marginal input indices */
OrthogonalBasis getMarginal(const Indices & indices) const override;

/** Virtual constructor */
OrthogonalProductFunctionFactory * clone() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#define OPENTURNS_ORTHOGONALPRODUCTPOLYNOMIALFACTORY_HXX

#include "openturns/OrthogonalFunctionFactory.hxx"
#include "openturns/OrthogonalProductFunctionFactory.hxx"
#include "openturns/Distribution.hxx"
#include "openturns/Indices.hxx"
#include "openturns/Point.hxx"
Expand Down Expand Up @@ -80,6 +81,10 @@ public:
Sample getNodesAndWeights(const Indices & degrees,
Point & weightsOut) const;

/** Get the function factory corresponding to the given input marginal indices */
using OrthogonalFunctionFactory::getMarginal;
OrthogonalBasis getMarginal(const Indices & indices) const override;

/** String converter */
String __repr__() const override;
String __str__(const String & offset = "") const override;
Expand Down
1 change: 1 addition & 0 deletions lib/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,7 @@ ot_check_test (LinearModelAnalysis_std)
ot_check_test (LinearModelValidation_std IGNOREOUT)
ot_check_test (KrigingAlgorithm_isotropic_std IGNOREOUT)
ot_check_test (OrthogonalProductPolynomialFactory_std)
ot_check_test (OrthogonalProductFunctionFactory_std)

if (HMAT_FOUND)
ot_check_test (KrigingAlgorithm_std_hmat)
Expand Down
162 changes: 162 additions & 0 deletions lib/test/t_OrthogonalProductFunctionFactory_std.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// -*- C++ -*-
/**
* @brief The test file of OrthogonalProductFunctionFactory class
*
* Copyright 2005-2024 Airbus-EDF-IMACS-ONERA-Phimeca
*
* This library is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "openturns/OT.hxx"
#include "openturns/OTtestcode.hxx"

using namespace OT;
using namespace OT::Test;

// Compute reference function value from index and point
Point computeTripleHaarFunctionValue(const UnsignedInteger & index, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
functionCollection[1] = HaarWaveletFactory();
functionCollection[2] = HaarWaveletFactory();
const LinearEnumerateFunction enumerate(dimension);
const TensorizedUniVariateFunctionFactory factory(functionCollection, enumerate);
const Function referenceFunction(factory.build(index));
const Point value(referenceFunction(point));
return value;
}

// Compute reference function value from multi-index and point
Point computeTripleHaarFunctionValue(const Indices & indices, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
const Point value(computeTripleHaarFunctionValue(index, point));
return value;
}

Point computeHaarFourierFunctionValue(const UnsignedInteger & index, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
functionCollection[1] = HaarWaveletFactory();
functionCollection[2] = FourierSeriesFactory();
const LinearEnumerateFunction enumerate(dimension);
const TensorizedUniVariateFunctionFactory factory(functionCollection, enumerate);
const Function referenceFunction(factory.build(index));
const Point value(referenceFunction(point));
return value;
}

// Compute reference function value from multi-index and point
Point computeHaarFourierFunctionValue(const Indices & indices, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
const Point value(computeHaarFourierFunctionValue(index, point));
return value;
}


int main(int, char *[])
{
TESTPREAMBLE;
OStream fullprint(std::cout);
setRandomGenerator();


try
{
// Create the orthogonal basis
fullprint << "Create the orthogonal basis" << std::endl;
UnsignedInteger dimension = 3;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
functionCollection[1] = HaarWaveletFactory();
functionCollection[2] = HaarWaveletFactory();

// Create linear enumerate function
fullprint << "Create linear enumerate function" << std::endl;
LinearEnumerateFunction enumerateFunction(dimension);
OrthogonalProductFunctionFactory productBasis(functionCollection, enumerateFunction);
fullprint << productBasis.__str__() << std::endl;
fullprint << productBasis.__repr_markdown__() << std::endl;
// Test the build() method on a collection of functions
const Point center({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function function(productBasis.build(i));
assert_almost_equal(function(center), computeTripleHaarFunctionValue(i, center));
// Test build from multi-index
const Indices indices(enumerateFunction(i));
const Function function2(productBasis.build(indices));
assert_almost_equal(function2(center), computeTripleHaarFunctionValue(indices, center));
}

// Heterogeneous collection
fullprint << "Heterogeneous collection" << std::endl;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection2(dimension);
functionCollection2[0] = HaarWaveletFactory();
functionCollection2[1] = FourierSeriesFactory();
functionCollection2[2] = HaarWaveletFactory();
OrthogonalProductFunctionFactory productBasis2(functionCollection2);
fullprint << productBasis2.__str__() << std::endl;
fullprint << productBasis2.__repr_markdown__() << std::endl;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection4(productBasis2.getFunctionFamilyCollection());
assert_equal((int) functionCollection4.getSize(), 3);

// Test getMarginal
fullprint << "Test getMarginal" << std::endl;
UnsignedInteger dimension2 = 5;
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionCollection3(dimension2);
functionCollection3[0] = HaarWaveletFactory();
functionCollection3[1] = FourierSeriesFactory();
functionCollection3[2] = HaarWaveletFactory();
functionCollection3[3] = HaarWaveletFactory();
functionCollection3[4] = FourierSeriesFactory();
OrthogonalProductFunctionFactory productBasis5(functionCollection3);
Indices indices({0, 2, 4});
OrthogonalBasis productBasis6(productBasis5.getMarginal(indices));
fullprint << productBasis6.__str__() << std::endl;
// Test the build() method on a collection of functions
const Point center2({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function function(productBasis6.build(i));
assert_almost_equal(function(center2), computeHaarFourierFunctionValue(i, center2));
}
}
catch (TestFailed & ex)
{
std::cerr << ex << std::endl;
return ExitCode::Error;
}

return ExitCode::Success;
}
Loading

0 comments on commit d148b94

Please sign in to comment.