Skip to content

Commit

Permalink
Inherits from OrthogonalFunctionFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
mbaudin47 committed Jul 29, 2024
1 parent 6ebba5d commit 573421e
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,17 @@ EnumerateFunction OrthogonalFunctionFactory::getEnumerateFunction() const
throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getEnumerateFunction() const";
}

/* Get the function factory corresponding to marginal input indices */
OrthogonalFunctionFactory OrthogonalFunctionFactory::getMarginal(const Indices & ) const
{
throw NotYetImplementedException(HERE) << "In OrthogonalFunctionFactory::getMarginal() const";
}

Bool OrthogonalFunctionFactory::isOrthogonal() const
{
return true;
}


/* String converter */
String OrthogonalFunctionFactory::__repr__() const
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,23 @@ void OrthogonalProductFunctionFactory::buildMeasure(const FunctionFamilyCollecti
measure_ = JointDistribution(distributions);
}

/* Get marginal functions */
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const
/* Get the function factory corresponding to marginal input indices */
OrthogonalFunctionFactory OrthogonalProductFunctionFactory::getMarginal(const Indices & indices) const
{
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionColl(tensorizedFunctionFactory_.getFunctionFamilyCollection());
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionColl(getFunctionFamilyCollection());
const UnsignedInteger size = functionColl.getSize();
if (!indices.check(size))
throw InvalidArgumentException(HERE) << "The indices of a marginal sample must be in the range [0, size-1] and must be different";
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionMarginalCollection;
// Create list of factories corresponding to input marginal indices
OrthogonalProductFunctionFactory::FunctionFamilyCollection functionMarginalCollection;
for (UnsignedInteger index = 0; index < size; ++ index)
if (indices.contains(index))
functionMarginalCollection.add(functionColl[index]);
return functionMarginalCollection;
// Create function
const EnumerateFunction enumerateFunction(tensorizedFunctionFactory_.getEnumerateFunction());
const EnumerateFunction marginalEnumerateFunction(enumerateFunction.getMarginal(indices));
const OrthogonalProductFunctionFactory marginalFactory(functionMarginalCollection, marginalEnumerateFunction);
return marginalFactory;
}

END_NAMESPACE_OPENTURNS
Original file line number Diff line number Diff line change
Expand Up @@ -304,17 +304,21 @@ Sample OrthogonalProductPolynomialFactory::getNodesAndWeights(const Indices & de
return nodes;
}

/* Get marginal polynomials */
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const
/* Get the function factory corresponding to the input marginal indices */
OrthogonalFunctionFactory OrthogonalProductPolynomialFactory::getMarginal(const Indices & indices) const
{
const UnsignedInteger size = coll_.getSize();
if (!indices.check(size))
throw InvalidArgumentException(HERE) << "The indices of a marginal sample must be in the range [0, size-1] and must be different";
// Create list of factories corresponding to input marginal indices
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection polynomialMarginalCollection;
for (UnsignedInteger index = 0; index < size; ++ index)
if (indices.contains(index))
polynomialMarginalCollection.add(coll_[index]);
return polynomialMarginalCollection;
// Create function
const EnumerateFunction marginalEnumerateFunction(phi_.getMarginal(indices));
const OrthogonalProductPolynomialFactory marginalFactory(polynomialMarginalCollection, marginalEnumerateFunction);
return marginalFactory;
}

END_NAMESPACE_OPENTURNS
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ public:
/** Virtual constructor */
OrthogonalFunctionFactory * clone() const override;

/** Get the function factory corresponding to marginal input indices */
virtual OrthogonalFunctionFactory getMarginal(const Indices & indices) const;

Bool isOrthogonal() const override;

/** String converter */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ public:
/** Return the collection of univariate orthogonal polynomial families */
FunctionFamilyCollection getFunctionFamilyCollection() const;

/** Get marginal functions */
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection getMarginal(const Indices & indices) const;
/** Get the function factory corresponding to marginal input indices */
OrthogonalFunctionFactory getMarginal(const Indices & indices) const override;

/** Virtual constructor */
OrthogonalProductFunctionFactory * clone() const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ public:
Sample getNodesAndWeights(const Indices & degrees,
Point & weightsOut) const;

/** Get marginal functions */
PolynomialFamilyCollection getMarginal(const Indices & indices) const;
/** Get the function factory corresponding to the given input marginal indices */
using OrthogonalFunctionFactory::getMarginal;
OrthogonalFunctionFactory getMarginal(const Indices & indices) const override;

/** String converter */
String __repr__() const override;
Expand Down
18 changes: 15 additions & 3 deletions lib/test/t_OrthogonalProductFunctionFactory_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ using namespace OT::Test;

// Compute reference function value from index and point
Point computeFunctionValue(const UnsignedInteger & index, const Point & point)
{
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection functionCollection(dimension);
functionCollection[0] = HaarWaveletFactory();
Expand All @@ -42,6 +44,8 @@ Point computeFunctionValue(const UnsignedInteger & index, const Point & point)
// Compute reference function value from multi-index and point
Point computeFunctionValue(const Indices & indices, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
Expand Down Expand Up @@ -108,8 +112,16 @@ int main(int, char *[])
functionCollection3[4] = FourierSeriesFactory();
OrthogonalProductFunctionFactory productBasis5(functionCollection3);
Indices indices({0, 2, 4});
TensorizedUniVariateFunctionFactory::FunctionFamilyCollection productBasis6(productBasis5.getMarginal(indices));
assert_equal(productBasis6.getSize(), indices.getSize());
OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices));
fullprint << productBasis6.__str__() << std::endl;
// Test the build() method on a collection of functions
const Point center2({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function function(productBasis6.build(i));
assert_almost_equal(function(center2), computeFunctionValue(i, center2));
}
}
catch (TestFailed & ex)
{
Expand Down
1 change: 1 addition & 0 deletions lib/test/t_OrthogonalProductFunctionFactory_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ Heterogeneous collection
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
class=OrthogonalProductFunctionFactory factory=class=TensorizedUniVariateFunctionFactory univariate function collection=[class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1,class=UniVariateFunctionFamily implementation=class=FourierSeriesFactory measure=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159,class=UniVariateFunctionFamily implementation=class=HaarWaveletFactory measure=class=Uniform name=Uniform dimension=1 a=0 b=1] enumerate function=class=LinearEnumerateFunction dimension=3 measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159 marginal[2]=class=Uniform name=Uniform dimension=1 a=0 b=1
Test getMarginal
class=OrthogonalFunctionFactory measure=class=JointDistribution name=JointDistribution dimension=3 copula=class=IndependentCopula name=IndependentCopula dimension=3 marginal[0]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[1]=class=Uniform name=Uniform dimension=1 a=0 b=1 marginal[2]=class=Uniform name=Uniform dimension=1 a=-3.14159 b=3.14159
17 changes: 15 additions & 2 deletions lib/test/t_OrthogonalProductPolynomialFactory_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ using namespace OT::Test;
// Compute reference function value from index and point
Point computePolynomialValue(const UnsignedInteger & index, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
// Compute the multi-indices using the EnumerateFunction
Expand All @@ -45,6 +47,8 @@ Point computePolynomialValue(const UnsignedInteger & index, const Point & point)
// Compute reference function value from multi-index and point
Point computePolynomialValue(const Indices & indices, const Point & point)
{
if (point.getDimension() != 3)
throw InvalidArgumentException(HERE) << "Expected a dimension 3 point, but dimension is " << point.getDimension();
const UnsignedInteger dimension = 3;
const LinearEnumerateFunction enumerate(dimension);
const UnsignedInteger index = enumerate.inverse(indices);
Expand Down Expand Up @@ -112,12 +116,21 @@ int main(int, char *[])
fullprint << productBasis4.__repr_markdown__() << std::endl;

// Test getMarginal
fullprint << "Test getMarginal" << std::endl;
UnsignedInteger dimension2 = 5;
Collection<Distribution> marginals4(dimension2, Uniform(0.0, 1.0));
OrthogonalProductPolynomialFactory productBasis5(marginals4);
Indices indices({0, 2, 4});
OrthogonalProductPolynomialFactory::PolynomialFamilyCollection productBasis6(productBasis5.getMarginal(indices));
assert_equal(productBasis6.getSize(), indices.getSize());
OrthogonalFunctionFactory productBasis6(productBasis5.getMarginal(indices));
fullprint << productBasis6.__str__() << std::endl;
// Test the build() method on a collection of functions
const Point center2({0.5, 0.5, 0.5});
for (UnsignedInteger i = 0; i < 10; ++ i)
{
// Test build from index
const Function polynomial(productBasis6.build(i));
assert_almost_equal(polynomial(center2), computePolynomialValue(i, center2));
}
}
catch (TestFailed & ex)
{
Expand Down
4 changes: 3 additions & 1 deletion python/test/t_OrthogonalProductFunctionFactory_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@
enumerateFunction,
)
productBasisMarginal = productBasis.getMarginal([0, 2, 4])
assert productBasisMarginal.getSize() == 3
for i in range(20):
function = productBasisMarginal.build(i)

4 changes: 3 additions & 1 deletion python/test/t_OrthogonalProductPolynomialFactory_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@
enumerateFunction,
)
productBasisMarginal = productBasis.getMarginal([0, 2, 4])
assert productBasisMarginal.getSize() == 3
for i in range(20):
function = productBasisMarginal.build(i)

0 comments on commit 573421e

Please sign in to comment.