Skip to content

Commit

Permalink
Use OrthogonalBasis
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaudin47 committed Sep 4, 2024
1 parent 548655a commit f1a382c
Show file tree
Hide file tree
Showing 10 changed files with 63 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "openturns/OSS.hxx"
#include "openturns/PersistentObjectFactory.hxx"
#include "openturns/Exception.hxx"
#include "openturns/OrthogonalBasis.hxx"

BEGIN_NAMESPACE_OPENTURNS

Expand Down Expand Up @@ -89,9 +90,9 @@ EnumerateFunction OrthogonalFunctionFactory::getEnumerateFunction() const
}

/* Get the function factory corresponding to marginal input indices */
OrthogonalFunctionFactory OrthogonalFunctionFactory::getMarginal(const Indices & ) const
OrthogonalBasis OrthogonalFunctionFactory::getMarginal(const Indices & ) const
{
throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getMarginal() const";
throw NotYetImplementedException(HERE) << "In OrthogonalBasis::getMarginal() const";
}

Bool OrthogonalFunctionFactory::isOrthogonal() const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ void OrthogonalProductFunctionFactory::buildMeasure(const FunctionFamilyCollecti
}

/* Get the function factory corresponding to marginal input indices */
OrthogonalFunctionFactory OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const
OrthogonalBasis OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const
{
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionColl(getFunctionFamilyCollection());
const UnsignedInteger size = functionColl.getSize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ Sample OrthogonalProductPolynomialFactory::getNodesAndWeights(const Indices & de
}

/* Get the function factory corresponding to the input marginal indices */
OrthogonalFunctionFactory OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const
OrthogonalBasis OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const
{
const UnsignedInteger size = coll_.getSize();
if (!indices.check(size))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

BEGIN_NAMESPACE_OPENTURNS

class OrthogonalBasis;

/**
* @class OrthogonalFunctionFactory
*
Expand Down Expand Up @@ -62,8 +64,9 @@ public:
/** Virtual constructor */
OrthogonalFunctionFactory * clone() const override;


/** Get the function factory corresponding to marginal input indices */
virtual OrthogonalFunctionFactory getMarginal(const Indices & indices) const;
virtual OrthogonalBasis getMarginal(const Indices & indices) const;

Bool isOrthogonal() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "openturns/OrthogonalUniVariateFunctionFamily.hxx"
#include "openturns/TensorizedUniVariateFunctionFactory.hxx"
#include "openturns/EnumerateFunction.hxx"
#include "openturns/OrthogonalBasis.hxx"


BEGIN_NAMESPACE_OPENTURNS

Expand Down Expand Up @@ -72,7 +74,7 @@ public:
FunctionFamilyCollection getFunctionFamilyCollection() const;

/** Get the function factory corresponding to marginal input indices */
OrthogonalFunctionFactory getMarginal(const Indices & indices) const override;
OrthogonalBasis getMarginal(const Indices & indices) const override;

/** Virtual constructor */
OrthogonalProductFunctionFactory * clone() const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public:

/** Get the function factory corresponding to the given input marginal indices */
using OrthogonalFunctionFactory::getMarginal;
OrthogonalFunctionFactory getMarginal(const Indices & indices) const override;
OrthogonalBasis getMarginal(const Indices & indices) const override;

/** String converter */
String __repr__() const override;
Expand Down
43 changes: 36 additions & 7 deletions lib/test/t_OrthogonalProductFunctionFactory_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using namespace OT;
using namespace OT::Test;

// Compute reference function value from index and point
Point computeFunctionValue(const UnsignedInteger & index, const Point & point)
Point computeTripleHaarFunctionValue(const UnsignedInteger & index, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
Expand All @@ -42,17 +42,46 @@ Point computeFunctionValue(const UnsignedInteger & index, const Point & point)
}

// Compute reference function value from multi-index and point
Point computeFunctionValue(const Indices & indices, const Point & point)
Point computeTripleHaarFunctionValue(const Indices & indices, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
const Point value(computeFunctionValue(index, point));
const Point value(computeTripleHaarFunctionValue(index, point));
return value;
}

Point computeHaarFourierFunctionValue(const UnsignedInteger & index, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
functionCollection[1] = HaarWaveletFactory();
functionCollection[2] = FourierSeriesFactory();
const LinearEnumerateFunction enumerate(dimension);
const TensorizedUniVariateFunctionFactory factory(functionCollection, enumerate);
const Function referenceFunction(factory.build(index));
const Point value(referenceFunction(point));
return value;
}

// Compute reference function value from multi-index and point
Point computeHaarFourierFunctionValue(const Indices & indices, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
const Point value(computeHaarFourierFunctionValue(index, point));
return value;
}


int main(int, char *[])
{
TESTPREAMBLE;
Expand Down Expand Up @@ -82,11 +111,11 @@ int main(int, char *[])
{
// Test build from index
const Function function(productBasis.build(i));
assert_almost_equal(function(center), computeFunctionValue(i, center));
assert_almost_equal(function(center), computeTripleHaarFunctionValue(i, center));
// Test build from multi-index
const Indices indices(enumerateFunction(i));
const Function function2(productBasis.build(indices));
assert_almost_equal(function2(center), computeFunctionValue(indices, center));
assert_almost_equal(function2(center), computeTripleHaarFunctionValue(indices, center));
}

// Heterogeneous collection
Expand All @@ -112,15 +141,15 @@ int main(int, char *[])
functionCollection3[4] = FourierSeriesFactory();
OrthogonalProductFunctionFactory productBasis5(functionCollection3);
Indices indices({0, 2, 4});
OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices));
OrthogonalBasis productBasis6(productBasis5.getMarginal(indices));
fullprint << productBasis6.__str__() << std::endl;
// Test the build() method on a collection of functions
const Point center2({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function function(productBasis6.build(i));
assert_almost_equal(function(center2), computeFunctionValue(i, center2));
assert_almost_equal(function(center2), computeHaarFourierFunctionValue(i, center2));
}
}
catch (TestFailed & ex)
Expand Down
2 changes: 1 addition & 1 deletion lib/test/t_OrthogonalProductFunctionFactory_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ Heterogeneous collection
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
Test getMarginal
class=OrthogonalFunctionFactory measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159
2 changes: 1 addition & 1 deletion lib/test/t_OrthogonalProductPolynomialFactory_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ int main(int, char *[])
Collection<Distribution> marginals4(dimension2, Uniform(0.0, 1.0));
OrthogonalProductPolynomialFactory productBasis5(marginals4);
Indices indices({0, 2, 4});
OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices));
OrthogonalBasis productBasis6(productBasis5.getMarginal(indices));
fullprint << productBasis6.__str__() << std::endl;
// Test the build() method on a collection of functions
const Point center2({0.5, 0.5, 0.5});
Expand Down
12 changes: 12 additions & 0 deletions lib/test/t_OrthogonalProductPolynomialFactory_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,15 @@ OrthogonalProductPolynomialFactory
| 1 | LegendreFactory |
| 2 | AdaptiveStieltjesAlgorithm |

Test getMarginal
OrthogonalProductPolynomialFactory
- measure=Distribution
- isOrthogonal=true
- enumerateFunction=class=LinearEnumerateFunction dimension=3

| Index | Type |
|-------|-----------------|
| 0 | LegendreFactory |
| 1 | LegendreFactory |
| 2 | LegendreFactory |

0 comments on commit f1a382c

Please sign in to comment.