diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx index 41c746f2f6..db39643f7f 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalFunctionFactory.cxx @@ -88,13 +88,17 @@ EnumerateFunction OrthogonalFunctionFactory::getEnumerateFunction() const throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getEnumerateFunction() const"; } +/* Get the function factory corresponding to marginal input indices */ +OrthogonalFunctionFactory OrthogonalFunctionFactory::getMarginal(const Indices & ) const +{ + throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getMarginal() const"; +} Bool OrthogonalFunctionFactory::isOrthogonal() const { return true; } - /* String converter */ String OrthogonalFunctionFactory::__repr__() const { diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx index 8cf5e43651..124e489805 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductFunctionFactory.cxx @@ -153,18 +153,23 @@ void OrthogonalProductFunctionFactory::buildMeasure(const FunctionFamilyCollecti measure_ = JointDistribution(distributions); } -/* Get marginal functions */ -TensorizedUniVariateFunctionFactory::FunctionFamilyCollection OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const +/* Get the function factory corresponding to marginal input indices */ +OrthogonalFunctionFactory OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const { - TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionColl(tensorizedFunctionFactory_.getFunctionFamilyCollection()); + OrthogonalProductFunctionFactory::FunctionFamilyCollection functionColl(getFunctionFamilyCollection()); const UnsignedInteger size = functionColl.getSize(); if (!indices.check(size)) throw InvalidArgumentException(HERE) << "The indices of a marginal sample must be in the range [0, size-1] and must be different"; - TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionMarginalCollection; + // Create list of factories corresponding to input marginal indices + OrthogonalProductFunctionFactory::FunctionFamilyCollection functionMarginalCollection; for (UnsignedInteger index = 0; index < size; ++ index) if (indices.contains(index)) functionMarginalCollection.add(functionColl[index]); - return functionMarginalCollection; + // Create function + const EnumerateFunction enumerateFunction(tensorizedFunctionFactory_.getEnumerateFunction()); + const EnumerateFunction marginalEnumerateFunction(enumerateFunction.getMarginal(indices)); + const OrthogonalProductFunctionFactory marginalFactory(functionMarginalCollection, marginalEnumerateFunction); + return marginalFactory; } END_NAMESPACE_OPENTURNS diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx index 1952b2060c..4a54b7b412 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/OrthogonalProductPolynomialFactory.cxx @@ -304,17 +304,21 @@ Sample OrthogonalProductPolynomialFactory::getNodesAndWeights(const Indices & de return nodes; } -/* Get marginal polynomials */ -OrthogonalProductPolynomialFactory::PolynomialFamilyCollection OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const +/* Get the function factory corresponding to the input marginal indices */ +OrthogonalFunctionFactory OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const { const UnsignedInteger size = coll_.getSize(); if (!indices.check(size)) throw InvalidArgumentException(HERE) << "The indices of a marginal sample must be in the range [0, size-1] and must be different"; + // Create list of factories corresponding to input marginal indices OrthogonalProductPolynomialFactory::PolynomialFamilyCollection polynomialMarginalCollection; for (UnsignedInteger index = 0; index < size; ++ index) if (indices.contains(index)) polynomialMarginalCollection.add(coll_[index]); - return polynomialMarginalCollection; + // Create function + const EnumerateFunction marginalEnumerateFunction(phi_.getMarginal(indices)); + const OrthogonalProductPolynomialFactory marginalFactory(polynomialMarginalCollection, marginalEnumerateFunction); + return marginalFactory; } END_NAMESPACE_OPENTURNS diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx index c622202bcf..82fbe15bb9 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalFunctionFactory.hxx @@ -62,6 +62,9 @@ public: /** Virtual constructor */ OrthogonalFunctionFactory * clone() const override; + /** Get the function factory corresponding to marginal input indices */ + virtual OrthogonalFunctionFactory getMarginal(const Indices & indices) const; + Bool isOrthogonal() const override; /** String converter */ diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx index d69349909e..bd80619d94 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductFunctionFactory.hxx @@ -71,8 +71,8 @@ public: /** Return the collection of univariate orthogonal polynomial families */ FunctionFamilyCollection getFunctionFamilyCollection() const; - /** Get marginal functions */ - TensorizedUniVariateFunctionFactory::FunctionFamilyCollection getMarginal(const Indices & indices) const; + /** Get the function factory corresponding to marginal input indices */ + OrthogonalFunctionFactory getMarginal(const Indices & indices) const override; /** Virtual constructor */ OrthogonalProductFunctionFactory * clone() const override; diff --git a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx index b79c49876f..e6c2760091 100644 --- a/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx +++ b/lib/src/Uncertainty/Algorithm/OrthogonalBasis/openturns/OrthogonalProductPolynomialFactory.hxx @@ -81,8 +81,9 @@ public: Sample getNodesAndWeights(const Indices & degrees, Point & weightsOut) const; - /** Get marginal functions */ - PolynomialFamilyCollection getMarginal(const Indices & indices) const; + /** Get the function factory corresponding to the given input marginal indices */ + using OrthogonalFunctionFactory::getMarginal; + OrthogonalFunctionFactory getMarginal(const Indices & indices) const override; /** String converter */ String __repr__() const override; diff --git a/lib/test/t_OrthogonalProductFunctionFactory_std.cxx b/lib/test/t_OrthogonalProductFunctionFactory_std.cxx index b6de8708fb..c84dc1fae2 100644 --- a/lib/test/t_OrthogonalProductFunctionFactory_std.cxx +++ b/lib/test/t_OrthogonalProductFunctionFactory_std.cxx @@ -26,7 +26,9 @@ using namespace OT::Test; // Compute reference function value from index and point Point computeFunctionValue(const UnsignedInteger & index, const Point & point) -{ +{ + if (point.getDimension() != 3) + throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); const UnsignedInteger dimension = 3; TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension); functionCollection[0] = HaarWaveletFactory(); @@ -42,6 +44,8 @@ Point computeFunctionValue(const UnsignedInteger & index, const Point & point) // Compute reference function value from multi-index and point Point computeFunctionValue(const Indices & indices, const Point & point) { + if (point.getDimension() != 3) + throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); const UnsignedInteger dimension = 3; const LinearEnumerateFunction enumerate(dimension); const UnsignedInteger index = enumerate.inverse(indices); @@ -108,8 +112,16 @@ int main(int, char *[]) functionCollection3[4] = FourierSeriesFactory(); OrthogonalProductFunctionFactory productBasis5(functionCollection3); Indices indices({0, 2, 4}); - TensorizedUniVariateFunctionFactory::FunctionFamilyCollection productBasis6(productBasis5.getMarginal(indices)); - assert_equal(productBasis6.getSize(), indices.getSize()); + OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices)); + fullprint << productBasis6.__str__() << std::endl; + // Test the build() method on a collection of functions + const Point center2({0.5, 0.5, 0.5}); + for (UnsignedInteger i = 0; i < 10; ++ i) + { + // Test build from index + const Function function(productBasis6.build(i)); + assert_almost_equal(function(center2), computeFunctionValue(i, center2)); + } } catch (TestFailed & ex) { diff --git a/lib/test/t_OrthogonalProductFunctionFactory_std.expout b/lib/test/t_OrthogonalProductFunctionFactory_std.expout index 52bf12dd5b..5e0273c002 100644 --- a/lib/test/t_OrthogonalProductFunctionFactory_std.expout +++ b/lib/test/t_OrthogonalProductFunctionFactory_std.expout @@ -6,3 +6,4 @@ Heterogeneous collection class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1 class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1 Test getMarginal +class=OrthogonalFunctionFactory measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 diff --git a/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx b/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx index 11245f0dc0..d21edcd824 100644 --- a/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx +++ b/lib/test/t_OrthogonalProductPolynomialFactory_std.cxx @@ -27,6 +27,8 @@ using namespace OT::Test; // Compute reference function value from index and point Point computePolynomialValue(const UnsignedInteger & index, const Point & point) { + if (point.getDimension() != 3) + throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); const UnsignedInteger dimension = 3; const LinearEnumerateFunction enumerate(dimension); // Compute the multi-indices using the EnumerateFunction @@ -45,6 +47,8 @@ Point computePolynomialValue(const UnsignedInteger & index, const Point & point) // Compute reference function value from multi-index and point Point computePolynomialValue(const Indices & indices, const Point & point) { + if (point.getDimension() != 3) + throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension(); const UnsignedInteger dimension = 3; const LinearEnumerateFunction enumerate(dimension); const UnsignedInteger index = enumerate.inverse(indices); @@ -112,12 +116,21 @@ int main(int, char *[]) fullprint << productBasis4.__repr_markdown__() << std::endl; // Test getMarginal + fullprint << "Test getMarginal" << std::endl; UnsignedInteger dimension2 = 5; Collection marginals4(dimension2, Uniform(0.0, 1.0)); OrthogonalProductPolynomialFactory productBasis5(marginals4); Indices indices({0, 2, 4}); - OrthogonalProductPolynomialFactory::PolynomialFamilyCollection productBasis6(productBasis5.getMarginal(indices)); - assert_equal(productBasis6.getSize(), indices.getSize()); + OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices)); + fullprint << productBasis6.__str__() << std::endl; + // Test the build() method on a collection of functions + const Point center2({0.5, 0.5, 0.5}); + for (UnsignedInteger i = 0; i < 10; ++ i) + { + // Test build from index + const Function polynomial(productBasis6.build(i)); + assert_almost_equal(polynomial(center2), computePolynomialValue(i, center2)); + } } catch (TestFailed & ex) { diff --git a/python/test/t_OrthogonalProductFunctionFactory_std.py b/python/test/t_OrthogonalProductFunctionFactory_std.py index 3fdb6c9be7..acf3bfd01d 100644 --- a/python/test/t_OrthogonalProductFunctionFactory_std.py +++ b/python/test/t_OrthogonalProductFunctionFactory_std.py @@ -37,4 +37,6 @@ enumerateFunction, ) productBasisMarginal = productBasis.getMarginal([0, 2, 4]) -assert productBasisMarginal.getSize() == 3 +for i in range(20): + function = productBasisMarginal.build(i) + diff --git a/python/test/t_OrthogonalProductPolynomialFactory_std.py b/python/test/t_OrthogonalProductPolynomialFactory_std.py index 3b58330b7f..72ccab9428 100755 --- a/python/test/t_OrthogonalProductPolynomialFactory_std.py +++ b/python/test/t_OrthogonalProductPolynomialFactory_std.py @@ -37,4 +37,6 @@ enumerateFunction, ) productBasisMarginal = productBasis.getMarginal([0, 2, 4]) -assert productBasisMarginal.getSize() == 3 +for i in range(20): + function = productBasisMarginal.build(i) +