From f1a382c9836038277fe23dc7f4480f74bccb3bbc Mon Sep 17 00:00:00 2001 From: Michael BAUDIN Date: Wed, 4 Sep 2024 14:03:03 +0200 Subject: [PATCH] Use OrthogonalBasis --- .../OrthogonalFunctionFactory.cxx | 5 ++- .../OrthogonalProductFunctionFactory.cxx | 2 +- .../OrthogonalProductPolynomialFactory.cxx | 2 +- .../openturns/OrthogonalFunctionFactory.hxx | 5 ++- .../OrthogonalProductFunctionFactory.hxx | 4 +- .../OrthogonalProductPolynomialFactory.hxx | 2 +- ...t_OrthogonalProductFunctionFactory_std.cxx | 43 ++++++++++++++++--- ...rthogonalProductFunctionFactory_std.expout | 2 +- ...OrthogonalProductPolynomialFactory_std.cxx | 2 +- ...hogonalProductPolynomialFactory_std.expout | 12 ++++++ 10 files changed, 63 insertions(+), 16 deletions(-) diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx index db39643f7f..889cf9dbc7 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx @@ -22,6 +22,7 @@ #include "openturns/OSS.hxx" #include "openturns/PersistentObjectFactory.hxx" #include "openturns/Exception.hxx" +#include "openturns/OrthogonalBasis.hxx" BEGIN_NAMESPACE_OPENTURNS @@ -89,9 +90,9 @@ EnumerateFunction OrthogonalFunctionFactory::getEnumerateFunction() const } /* Get the function factory corresponding to marginal input indices */ -OrthogonalFunctionFactory OrthogonalFunctionFactory::getMarginal(const Indices & ) const +OrthogonalBasis OrthogonalFunctionFactory::getMarginal(const Indices & ) const { - throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getMarginal() const"; + throw NotYetImplementedException(HERE) << "In OrthogonalBasis::getMarginal() const"; } Bool OrthogonalFunctionFactory::isOrthogonal() const diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx index 124e489805..21ce04fca4 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx @@ -154,7 +154,7 @@ void OrthogonalProductFunctionFactory::buildMeasure(const FunctionFamilyCollecti } /* Get the function factory corresponding to marginal input indices */ -OrthogonalFunctionFactory OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const +OrthogonalBasis OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const { OrthogonalProductFunctionFactory::FunctionFamilyCollection functionColl(getFunctionFamilyCollection()); const UnsignedInteger size = functionColl.getSize(); diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx index 4a54b7b412..43a1ad1231 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx @@ -305,7 +305,7 @@ Sample OrthogonalProductPolynomialFactory::getNodesAndWeights(const Indices & de } /* Get the function factory corresponding to the input marginal indices */ -OrthogonalFunctionFactory OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const +OrthogonalBasis OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const { const UnsignedInteger size = coll_.getSize(); if (!indices.check(size)) diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx index 82fbe15bb9..1e41b09e31 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx @@ -28,6 +28,8 @@ BEGIN_NAMESPACE_OPENTURNS +class OrthogonalBasis; + /** * @class OrthogonalFunctionFactory * @@ -62,8 +64,9 @@ public: /** Virtual constructor */ OrthogonalFunctionFactory * clone() const override; + /** Get the function factory corresponding to marginal input indices */ - virtual OrthogonalFunctionFactory getMarginal(const Indices & indices) const; + virtual OrthogonalBasis getMarginal(const Indices & indices) const; Bool isOrthogonal() const override; diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx index bd80619d94..158a47f67b 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx @@ -27,6 +27,8 @@ #include "openturns/OrthogonalUniVariateFunctionFamily.hxx" #include "openturns/TensorizedUniVariateFunctionFactory.hxx" #include "openturns/EnumerateFunction.hxx" +#include "openturns/OrthogonalBasis.hxx" + BEGIN_NAMESPACE_OPENTURNS @@ -72,7 +74,7 @@ public: FunctionFamilyCollection getFunctionFamilyCollection() const; /** Get the function factory corresponding to marginal input indices */ - OrthogonalFunctionFactory getMarginal(const Indices & indices) const override; + OrthogonalBasis getMarginal(const Indices & indices) const override; /** Virtual constructor */ OrthogonalProductFunctionFactory * clone() const override; diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx index e6c2760091..da892e9d29 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx @@ -83,7 +83,7 @@ public: /** Get the function factory corresponding to the given input marginal indices */ using OrthogonalFunctionFactory::getMarginal; - OrthogonalFunctionFactory getMarginal(const Indices & indices) const override; + OrthogonalBasis getMarginal(const Indices & indices) const override; /** String converter */ String __repr__() const override; diff --git a/lib/test/t_OrthogonalProductFunctionFactory_std.cxx b/lib/test/t_OrthogonalProductFunctionFactory_std.cxx index c84dc1fae2..8da06a66c9 100644 --- a/lib/test/t_OrthogonalProductFunctionFactory_std.cxx +++ b/lib/test/t_OrthogonalProductFunctionFactory_std.cxx @@ -25,7 +25,7 @@ using namespace OT; using namespace OT::Test; // Compute reference function value from index and point -Point computeFunctionValue(const UnsignedInteger & index, const Point & point) +Point computeTripleHaarFunctionValue(const UnsignedInteger & index, const Point & point) { if (point.getDimension() != 3) throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); @@ -42,17 +42,46 @@ Point computeFunctionValue(const UnsignedInteger & index, const Point & point) } // Compute reference function value from multi-index and point -Point computeFunctionValue(const Indices & indices, const Point & point) +Point computeTripleHaarFunctionValue(const Indices & indices, const Point & point) { if (point.getDimension() != 3) throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); const UnsignedInteger dimension = 3; const LinearEnumerateFunction enumerate(dimension); const UnsignedInteger index = enumerate.inverse(indices); - const Point value(computeFunctionValue(index, point)); + const Point value(computeTripleHaarFunctionValue(index, point)); return value; } +Point computeHaarFourierFunctionValue(const UnsignedInteger & index, const Point & point) +{ + if (point.getDimension() != 3) + throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); + const UnsignedInteger dimension = 3; + TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension); + functionCollection[0] = HaarWaveletFactory(); + functionCollection[1] = HaarWaveletFactory(); + functionCollection[2] = FourierSeriesFactory(); + const LinearEnumerateFunction enumerate(dimension); + const TensorizedUniVariateFunctionFactory factory(functionCollection, enumerate); + const Function referenceFunction(factory.build(index)); + const Point value(referenceFunction(point)); + return value; +} + +// Compute reference function value from multi-index and point +Point computeHaarFourierFunctionValue(const Indices & indices, const Point & point) +{ + if (point.getDimension() != 3) + throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); + const UnsignedInteger dimension = 3; + const LinearEnumerateFunction enumerate(dimension); + const UnsignedInteger index = enumerate.inverse(indices); + const Point value(computeHaarFourierFunctionValue(index, point)); + return value; +} + + int main(int, char *[]) { TESTPREAMBLE; @@ -82,11 +111,11 @@ int main(int, char *[]) { // Test build from index const Function function(productBasis.build(i)); - assert_almost_equal(function(center), computeFunctionValue(i, center)); + assert_almost_equal(function(center), computeTripleHaarFunctionValue(i, center)); // Test build from multi-index const Indices indices(enumerateFunction(i)); const Function function2(productBasis.build(indices)); - assert_almost_equal(function2(center), computeFunctionValue(indices, center)); + assert_almost_equal(function2(center), computeTripleHaarFunctionValue(indices, center)); } // Heterogeneous collection @@ -112,7 +141,7 @@ int main(int, char *[]) functionCollection3[4] = FourierSeriesFactory(); OrthogonalProductFunctionFactory productBasis5(functionCollection3); Indices indices({0, 2, 4}); - OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices)); + OrthogonalBasis productBasis6(productBasis5.getMarginal(indices)); fullprint << productBasis6.__str__() << std::endl; // Test the build() method on a collection of functions const Point center2({0.5, 0.5, 0.5}); @@ -120,7 +149,7 @@ int main(int, char *[]) { // Test build from index const Function function(productBasis6.build(i)); - assert_almost_equal(function(center2), computeFunctionValue(i, center2)); + assert_almost_equal(function(center2), computeHaarFourierFunctionValue(i, center2)); } } catch (TestFailed & ex) diff --git a/lib/test/t_OrthogonalProductFunctionFactory_std.expout b/lib/test/t_OrthogonalProductFunctionFactory_std.expout index 5e0273c002..250eddf377 100644 --- a/lib/test/t_OrthogonalProductFunctionFactory_std.expout +++ b/lib/test/t_OrthogonalProductFunctionFactory_std.expout @@ -6,4 +6,4 @@ Heterogeneous collection class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1 class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1 Test getMarginal -class=OrthogonalFunctionFactory measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 +class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 diff --git a/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx b/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx index d21edcd824..652fc4068c 100644 --- a/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx +++ b/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx @@ -121,7 +121,7 @@ int main(int, char *[]) Collection marginals4(dimension2, Uniform(0.0, 1.0)); OrthogonalProductPolynomialFactory productBasis5(marginals4); Indices indices({0, 2, 4}); - OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices)); + OrthogonalBasis productBasis6(productBasis5.getMarginal(indices)); fullprint << productBasis6.__str__() << std::endl; // Test the build() method on a collection of functions const Point center2({0.5, 0.5, 0.5}); diff --git a/lib/test/t_OrthogonalProductPolynomialFactory_std.expout b/lib/test/t_OrthogonalProductPolynomialFactory_std.expout index e9e1d6ef1f..31401121d9 100644 --- a/lib/test/t_OrthogonalProductPolynomialFactory_std.expout +++ b/lib/test/t_OrthogonalProductPolynomialFactory_std.expout @@ -86,3 +86,15 @@ OrthogonalProductPolynomialFactory | 1 | LegendreFactory | | 2 | AdaptiveStieltjesAlgorithm | +Test getMarginal +OrthogonalProductPolynomialFactory +- measure=Distribution +- isOrthogonal=true +- enumerateFunction=class=LinearEnumerateFunction dimension=3 + +| Index | Type | +|-------|-----------------| +| 0 | LegendreFactory | +| 1 | LegendreFactory | +| 2 | LegendreFactory | +