Skip to content

Commit

Permalink
New EnumerateFunction*.getMarginal()
Browse files Browse the repository at this point in the history
New LinearEnumerateFunction.getMarginal()
New HyperbolicAnisotropicEnumerateFunction.getMarginal()
New NormInfEnumerateFunction.getMarginal()
  • Loading branch information
mbaudin47 committed Jun 20, 2024
1 parent a7e812e commit f318729
Show file tree
Hide file tree
Showing 23 changed files with 279 additions and 0 deletions.
10 changes: 10 additions & 0 deletions lib/src/Base/Func/EnumerateFunction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,14 @@ Indices EnumerateFunction::getUpperBound() const
return getImplementation()->getUpperBound();
}

EnumerateFunction EnumerateFunction::getMarginal(const Indices & indices) const
{
return getImplementation()->getMarginal(indices);
}

EnumerateFunction EnumerateFunction::getMarginal(const UnsignedInteger i) const
{
return getImplementation()->getMarginal(i);
}

END_NAMESPACE_OPENTURNS
12 changes: 12 additions & 0 deletions lib/src/Base/Func/EnumerateFunctionImplementation.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*
*/
#include <limits>
#include "openturns/EnumerateFunction.hxx"
#include "openturns/EnumerateFunctionImplementation.hxx"
#include "openturns/OSS.hxx"
#include "openturns/PersistentObjectFactory.hxx"
Expand Down Expand Up @@ -148,5 +149,16 @@ void EnumerateFunctionImplementation::load(Advocate & adv)
upperBound_ = Indices(getDimension(), std::numeric_limits<UnsignedInteger>::max());
}

/* Returns the marginal enumerate function */
EnumerateFunction EnumerateFunctionImplementation::getMarginal(const Indices &) const
{
throw NotYetImplementedException(HERE) << "In EnumerateFunctionImplementation::getMarginal";
}

/* Returns the marginal enumerate function */
EnumerateFunction EnumerateFunctionImplementation::getMarginal(const UnsignedInteger i) const
{
return getMarginal(Indices({i}));
}

END_NAMESPACE_OPENTURNS
13 changes: 13 additions & 0 deletions lib/src/Base/Func/HyperbolicAnisotropicEnumerateFunction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*
*/
#include <algorithm>
#include "openturns/EnumerateFunction.hxx"
#include "openturns/HyperbolicAnisotropicEnumerateFunction.hxx"
#include "openturns/OSS.hxx"
#include "openturns/PersistentObjectFactory.hxx"
Expand Down Expand Up @@ -275,6 +276,18 @@ void HyperbolicAnisotropicEnumerateFunction::setUpperBound(const Indices & upper
initialize();
}

/* The marginal enumerate function */
EnumerateFunction HyperbolicAnisotropicEnumerateFunction::getMarginal(const Indices & indices) const
{
const UnsignedInteger inputDimension = getDimension();
indices.check(inputDimension);
const UnsignedInteger activeDimension = indices.getSize();
Point weightMarginal(activeDimension);
for (UnsignedInteger i = 0; i < activeDimension; ++i)
weightMarginal[i] = weight_[indices[i]];
const HyperbolicAnisotropicEnumerateFunction enumerateFunctionMarginal(weightMarginal, q_);
return enumerateFunctionMarginal;
}

/* Method save() stores the object through the StorageManager */
void HyperbolicAnisotropicEnumerateFunction::save(Advocate & adv) const
Expand Down
11 changes: 11 additions & 0 deletions lib/src/Base/Func/LinearEnumerateFunction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "openturns/EnumerateFunction.hxx"
#include "openturns/LinearEnumerateFunction.hxx"
#include "openturns/OSS.hxx"
#include "openturns/PersistentObjectFactory.hxx"
Expand Down Expand Up @@ -174,6 +175,16 @@ void LinearEnumerateFunction::setUpperBound(const Indices & /*upperBound*/)
throw NotYetImplementedException(HERE) << " in LinearEnumerateFunction::setUpperBound";
}

/* The marginal enumerate function */
EnumerateFunction LinearEnumerateFunction::getMarginal(const Indices & indices) const
{
const UnsignedInteger inputDimension = getDimension();
indices.check(inputDimension);
const UnsignedInteger activeDimension = indices.getSize();
const LinearEnumerateFunction enumerateFunctionMarginal(activeDimension);
return enumerateFunctionMarginal;
}

/* Method save() stores the object through the StorageManager */
void LinearEnumerateFunction::save(Advocate & adv) const
{
Expand Down
11 changes: 11 additions & 0 deletions lib/src/Base/Func/NormInfEnumerateFunction.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
* along with this library. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "openturns/EnumerateFunction.hxx"
#include "openturns/NormInfEnumerateFunction.hxx"
#include "openturns/OSS.hxx"
#include "openturns/PersistentObjectFactory.hxx"
Expand Down Expand Up @@ -143,6 +144,16 @@ UnsignedInteger NormInfEnumerateFunction::getMaximumDegreeStrataIndex(const Unsi
return maximumDegree / getDimension();
}

/* The marginal enumerate function */
EnumerateFunction NormInfEnumerateFunction::getMarginal(const Indices & indices) const
{
const UnsignedInteger inputDimension = getDimension();
indices.check(inputDimension);
const UnsignedInteger activeDimension = indices.getSize();
const NormInfEnumerateFunction enumerateFunctionMarginal(activeDimension);
return enumerateFunctionMarginal;
}

/* Method save() stores the object through the StorageManager */
void NormInfEnumerateFunction::save(Advocate & adv) const
{
Expand Down
6 changes: 6 additions & 0 deletions lib/src/Base/Func/openturns/EnumerateFunction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ public:
/** Basis size from degree */
UnsignedInteger getBasisSizeFromTotalDegree(const UnsignedInteger maximumDegree) const;

/** The marginal enumerate function */
EnumerateFunction getMarginal(const Indices & indices) const;

/** The marginal enumerate function */
EnumerateFunction getMarginal(const UnsignedInteger i) const;

/** Dimension accessor */
void setDimension(const UnsignedInteger dimension);
UnsignedInteger getDimension() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@

BEGIN_NAMESPACE_OPENTURNS

// Forward declaration
class EnumerateFunction;

/**
* @class EnumerateFunctionImplementation
*
Expand Down Expand Up @@ -69,6 +72,12 @@ public:
/** Basis size from total degree */
virtual UnsignedInteger getBasisSizeFromTotalDegree(const UnsignedInteger maximumDegree) const;

/** The marginal enumerate function */
virtual EnumerateFunction getMarginal(const Indices & indices) const;

/** The marginal enumerate function */
virtual EnumerateFunction getMarginal(const UnsignedInteger i) const;

/** Dimension accessor */
void setDimension(const UnsignedInteger dimension);
UnsignedInteger getDimension() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ public:
/** Upper bound accessor */
void setUpperBound(const Indices & upperBound) override;

/** The marginal enumerate function */
using EnumerateFunctionImplementation::getMarginal;
EnumerateFunction getMarginal(const Indices & indices) const override;

/** Method save() stores the object through the StorageManager */
void save(Advocate & adv) const override;

Expand Down
4 changes: 4 additions & 0 deletions lib/src/Base/Func/openturns/LinearEnumerateFunction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ public:
/** Upper bound accessor */
void setUpperBound(const Indices & upperBound) override;

/** The marginal enumerate function */
using EnumerateFunctionImplementation::getMarginal;
EnumerateFunction getMarginal(const Indices & indices) const override;

/** Method save() stores the object through the StorageManager */
void save(Advocate & adv) const override;

Expand Down
4 changes: 4 additions & 0 deletions lib/src/Base/Func/openturns/NormInfEnumerateFunction.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ public:
/** The index of the strata of degree max <= maximumDegree */
UnsignedInteger getMaximumDegreeStrataIndex(const UnsignedInteger maximumDegree) const override;

/** The marginal enumerate function */
using EnumerateFunctionImplementation::getMarginal;
EnumerateFunction getMarginal(const Indices & indices) const override;

/** Method save() stores the object through the StorageManager */
void save(Advocate & adv) const override;

Expand Down
11 changes: 11 additions & 0 deletions lib/test/t_HyperbolicAnisotropicEnumerateFunction_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ int main(int, char *[])
}
}

// Test getMarginal()
fullprint << "Test getMarginal()" << std::endl;
HyperbolicAnisotropicEnumerateFunction enumerateFunction(10, 0.5);
Indices indices({0, 2, 4, 6, 9});
EnumerateFunction marginalEnumerate(enumerateFunction.getMarginal(indices));
assert_equal(marginalEnumerate.getDimension(), indices.getSize());
for (UnsignedInteger index = 0; index < size; ++index)
{
Indices multiIndex(marginalEnumerate(index));
fullprint << "index=" << index << ", multi-index=" << multiIndex << std::endl;
}


}
Expand Down
26 changes: 26 additions & 0 deletions lib/test/t_HyperbolicAnisotropicEnumerateFunction_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,29 @@ index=23 [0,0,6,0]
index=24 [0,0,0,6]
And first 5 strata cardinals :[1,4,4,4,4]

Test getMarginal()
index=0, multi-index=[0,0,0,0,0]
index=1, multi-index=[1,0,0,0,0]
index=2, multi-index=[0,1,0,0,0]
index=3, multi-index=[0,0,1,0,0]
index=4, multi-index=[0,0,0,1,0]
index=5, multi-index=[0,0,0,0,1]
index=6, multi-index=[2,0,0,0,0]
index=7, multi-index=[0,2,0,0,0]
index=8, multi-index=[0,0,2,0,0]
index=9, multi-index=[0,0,0,2,0]
index=10, multi-index=[0,0,0,0,2]
index=11, multi-index=[3,0,0,0,0]
index=12, multi-index=[0,3,0,0,0]
index=13, multi-index=[0,0,3,0,0]
index=14, multi-index=[0,0,0,3,0]
index=15, multi-index=[0,0,0,0,3]
index=16, multi-index=[1,1,0,0,0]
index=17, multi-index=[1,0,1,0,0]
index=18, multi-index=[1,0,0,1,0]
index=19, multi-index=[1,0,0,0,1]
index=20, multi-index=[0,1,1,0,0]
index=21, multi-index=[0,1,0,1,0]
index=22, multi-index=[0,1,0,0,1]
index=23, multi-index=[0,0,1,1,0]
index=24, multi-index=[0,0,1,0,1]
12 changes: 12 additions & 0 deletions lib/test/t_LinearEnumerateFunction_std.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ int main(int, char *[])
fullprint << "index=" << index << ", multi-index=" << multiIndex << ", linear index=" << f.inverse(multiIndex) << std::endl;
}
}
// Test getMarginal()
fullprint << "Test getMarginal()" << std::endl;
LinearEnumerateFunction enumerateFunction(10);
Indices indices({0, 2, 4, 6, 9});
EnumerateFunction marginalEnumerate(enumerateFunction.getMarginal(indices));
assert_equal(marginalEnumerate.getDimension(), indices.getSize());
for (UnsignedInteger index = 0; index < size; ++index)
{
Indices multiIndex(marginalEnumerate(index));
fullprint << "index=" << index << ", multi-index=" << multiIndex << std::endl;
}

}
catch (TestFailed & ex)
{
Expand Down
11 changes: 11 additions & 0 deletions lib/test/t_LinearEnumerateFunction_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,14 @@ index=6, multi-index=[1,0,1], linear index=6
index=7, multi-index=[0,2,0], linear index=7
index=8, multi-index=[0,1,1], linear index=8
index=9, multi-index=[0,0,2], linear index=9
Test getMarginal()
index=0, multi-index=[0,0,0,0,0]
index=1, multi-index=[1,0,0,0,0]
index=2, multi-index=[0,1,0,0,0]
index=3, multi-index=[0,0,1,0,0]
index=4, multi-index=[0,0,0,1,0]
index=5, multi-index=[0,0,0,0,1]
index=6, multi-index=[2,0,0,0,0]
index=7, multi-index=[1,1,0,0,0]
index=8, multi-index=[1,0,1,0,0]
index=9, multi-index=[1,0,0,1,0]
18 changes: 18 additions & 0 deletions python/src/EnumerateFunctionImplementation_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,21 @@ ub : sequence of int
%enddef
%feature("docstring") OT::EnumerateFunctionImplementation::getUpperBound
OT_EnumerateFunction_getUpperBound_doc

// ---------------------------------------------------------------------

%define OT_EnumerateFunction_getMarginal_doc
"Get the marginal enumerate function.

Parameters
----------
indices : int or sequence of int, :math:`0 \leq i < n`
List of marginal indices.

Returns
-------
enumerateFunction : :class:`~openturns.EnumerateFunction`
The marginal enumerate function."
%enddef
%feature("docstring") OT::EnumerateFunctionImplementation::getMarginal
OT_EnumerateFunction_getMarginal_doc
2 changes: 2 additions & 0 deletions python/src/EnumerateFunction_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ OT_EnumerateFunction_setDimension_doc
OT_EnumerateFunction_setUpperBound_doc
%feature("docstring") OT::EnumerateFunction::getUpperBound
OT_EnumerateFunction_getUpperBound_doc
%feature("docstring") OT::EnumerateFunction::getMarginal
OT_EnumerateFunction_getMarginal_doc
1 change: 1 addition & 0 deletions python/src/HyperbolicAnisotropicEnumerateFunction_doc.i.in
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ Parameters
----------
w : sequence of float
Weights of the indices in each dimension."

22 changes: 22 additions & 0 deletions python/test/t_HyperbolicAnisotropicEnumerateFunction_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -412,3 +412,25 @@ index= 21 [0,6]
index= 22 [3,4]
index= 23 [2,5]
index= 24 [1,6]
Test getMarginal() from indices
index= 0 [0,0,0] 0
index= 1 [1,0,0] 1
index= 2 [0,1,0] 2
index= 3 [0,0,1] 3
index= 4 [2,0,0] 4
index= 5 [0,2,0] 5
index= 6 [0,0,2] 6
index= 7 [3,0,0] 7
index= 8 [0,3,0] 8
index= 9 [0,0,3] 9
Test getMarginal() from a single integer
index= 0 [0] 0
index= 1 [1] 1
index= 2 [2] 2
index= 3 [3] 3
index= 4 [4] 4
index= 5 [5] 5
index= 6 [6] 6
index= 7 [7] 7
index= 8 [8] 8
index= 9 [9] 9
16 changes: 16 additions & 0 deletions python/test/t_HyperbolicAnisotropicEnumerateFunction_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,19 @@
print("index=", index, repr(m))
assert m[0] <= 3, "wrong bound"
assert index == index_inv, "wrong inverse"
#
print("Test getMarginal() from indices")
f = ot.HyperbolicAnisotropicEnumerateFunction(5, 0.5)
marginalf = f.getMarginal([0, 3, 4])
for index in range(10):
m = marginalf(index)
index_inv = marginalf.inverse(m)
print("index=", index, repr(m), index_inv)

print("Test getMarginal() from a single integer")
f = ot.HyperbolicAnisotropicEnumerateFunction(5, 0.5)
marginalf = f.getMarginal(3)
for index in range(10):
m = marginalf(index)
index_inv = marginalf.inverse(m)
print("index=", index, repr(m), index_inv)
22 changes: 22 additions & 0 deletions python/test/t_LinearEnumerateFunction_std.expout
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,25 @@ degree 8 max_degree_strata_index 8
degree 8 size 165
degree 9 max_degree_strata_index 9
degree 9 size 220
Test getMarginal() from Indices
index= 0 [0,0,0] 0
index= 1 [1,0,0] 1
index= 2 [0,1,0] 2
index= 3 [0,0,1] 3
index= 4 [2,0,0] 4
index= 5 [1,1,0] 5
index= 6 [1,0,1] 6
index= 7 [0,2,0] 7
index= 8 [0,1,1] 8
index= 9 [0,0,2] 9
Test getMarginal() from a single integer
index= 0 [0] 0
index= 1 [1] 1
index= 2 [2] 2
index= 3 [3] 3
index= 4 [4] 4
index= 5 [5] 5
index= 6 [6] 6
index= 7 [7] 7
index= 8 [8] 8
index= 9 [9] 9
16 changes: 16 additions & 0 deletions python/test/t_LinearEnumerateFunction_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,19 @@
print("degree", d, "max_degree_strata_index", idx)
size = f.getBasisSizeFromTotalDegree(d)
print("degree", d, "size", size)
#
print("Test getMarginal() from Indices")
f = ot.LinearEnumerateFunction(5)
marginalf = f.getMarginal([0, 3, 4])
for index in range(10):
m = marginalf(index)
index_inv = marginalf.inverse(m)
print("index=", index, repr(m), index_inv)

print("Test getMarginal() from a single integer")
f = ot.LinearEnumerateFunction(5)
marginalf = f.getMarginal(3)
for index in range(10):
m = marginalf(index)
index_inv = marginalf.inverse(m)
print("index=", index, repr(m), index_inv)
Loading

0 comments on commit f318729

Please sign in to comment.