Skip to content

Commit

Permalink
refactor: add Type to interpolator class and factory method for usa…
Browse files Browse the repository at this point in the history
…bility (#110)

* refactor: make interpolator class composable with interpolationtype for usability

* feat: add generateinterpolate method

* feat: add bindings and binding tets

* fix: remove redundant tests based on vishwa's feedback

Co-authored-by: Vishwa Shah <vishwa2710@gmail.com>

* fix: address rest of comments

* docs: fix interpolator constructor

---------

Co-authored-by: Vishwa Shah <vishwa2710@gmail.com>
  • Loading branch information
apaletta3 and vishwa2710 authored Feb 8, 2024
1 parent 1da0c43 commit 9830b0b
Show file tree
Hide file tree
Showing 23 changed files with 458 additions and 225 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ IF (BUILD_UNIT_TESTS)

# TARGET_LINK_LIBRARIES (${UNIT_TESTS_TARGET} "GTest::GTest" "GTest::Main")
TARGET_LINK_LIBRARIES (${UNIT_TESTS_TARGET} "${GTEST_BOTH_LIBRARIES}")
TARGET_LINK_LIBRARIES (${UNIT_TESTS_TARGET} "gmock")
TARGET_LINK_LIBRARIES (${UNIT_TESTS_TARGET} "${SHARED_LIBRARY_TARGET}")

GTEST_DISCOVER_TESTS (${UNIT_TESTS_TARGET})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,64 @@
/// Apache License 2.0

#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator.hpp>

#include <OpenSpaceToolkitMathematicsPy/CurveFitting/Interpolator/BarycentricRational.cpp>
#include <OpenSpaceToolkitMathematicsPy/CurveFitting/Interpolator/CubicSpline.cpp>
#include <OpenSpaceToolkitMathematicsPy/CurveFitting/Interpolator/Linear.cpp>

using namespace pybind11;

using ostk::core::type::Shared;

using ostk::mathematics::curvefitting::Interpolator;
using ostk::mathematics::object::VectorXd;

// Trampoline class for virtual member functions
class PyInterpolator : public Interpolator
{
public:
using Interpolator::Interpolator;

// Trampoline (need one for each virtual function)

VectorXd evaluate(const VectorXd& aQueryVector) const override
{
PYBIND11_OVERRIDE_PURE(VectorXd, Interpolator, evaluate, aQueryVector);
}

double evaluate(const double& aQueryValue) const override
{
PYBIND11_OVERRIDE_PURE(double, Interpolator, evaluate, aQueryValue);
}
};

inline void OpenSpaceToolkitMathematicsPy_CurveFitting_Interpolator(pybind11::module& aModule)
{
class_<Interpolator, PyInterpolator, Shared<Interpolator>> interpolator_class(aModule, "Interpolator");

enum_<Interpolator::Type>(interpolator_class, "Type")

.value("BarycentricRational", Interpolator::Type::BarycentricRational)
.value("CubicSpline", Interpolator::Type::CubicSpline)
.value("Linear", Interpolator::Type::Linear)

;

interpolator_class

.def(init<const Interpolator::Type&>(), arg("interpolation_type"))

.def("get_interpolation_type", &Interpolator::getInterpolationType)

.def("evaluate", overload_cast<const VectorXd&>(&Interpolator::evaluate, const_), arg("x"))
.def("evaluate", overload_cast<const double&>(&Interpolator::evaluate, const_), arg("x"))

.def_static(
"generate_interpolator", &Interpolator::GenerateInterpolator, arg("interpolation_type"), arg("x"), arg("y")
)

;

// Create "interpolator" python submodule
auto interpolator = aModule.def_submodule("interpolator");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ inline void OpenSpaceToolkitMathematicsPy_CurveFitting_Interpolator_BarycentricR
{
using namespace pybind11;

using ostk::core::type::Shared;

using ostk::mathematics::curvefitting::Interpolator;
using ostk::mathematics::object::VectorXd;

using ostk::mathematics::curvefitting::interpolator::BarycentricRational;

// noncopyable class with Boost, removed in Pybind11
class_<BarycentricRational>(aModule, "BarycentricRational")
class_<BarycentricRational, Interpolator, Shared<BarycentricRational>>(aModule, "BarycentricRational")

.def(init<const VectorXd&, const VectorXd&>(), arg("x"), arg("y"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ inline void OpenSpaceToolkitMathematicsPy_CurveFitting_Interpolator_CubicSpline(
using namespace pybind11;

using ostk::core::type::Real;
using ostk::core::type::Shared;

using ostk::mathematics::curvefitting::Interpolator;
using ostk::mathematics::object::VectorXd;

using ostk::mathematics::curvefitting::interpolator::CubicSpline;

// noncopyable class with Boost, removed in Pybind11
class_<CubicSpline>(aModule, "CubicSpline")
class_<CubicSpline, Interpolator, Shared<CubicSpline>>(aModule, "CubicSpline")

.def(init<const VectorXd&, const VectorXd&>(), arg("x"), arg("y"))
.def(init<const VectorXd&, const Real&, const Real&>(), arg("y"), arg("x_0"), arg("h"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ inline void OpenSpaceToolkitMathematicsPy_CurveFitting_Interpolator_Linear(pybin
{
using namespace pybind11;

using ostk::core::type::Shared;

using ostk::mathematics::curvefitting::Interpolator;
using ostk::mathematics::object::VectorXd;

using ostk::mathematics::curvefitting::interpolator::Linear;

// noncopyable class with Boost, removed in Pybind11
class_<Linear>(aModule, "Linear")
class_<Linear, Interpolator, Shared<Linear>>(aModule, "Linear")

.def(init<const VectorXd&, const VectorXd&>(), arg("x"), arg("y"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@
#include <OpenSpaceToolkit/Mathematics/Geometry/2D/Object.hpp>

#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Composite.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/MultiPolygon.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Polygon.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/LineString.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Segment.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Line.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/LineString.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/MultiPolygon.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Point.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/PointSet.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Polygon.cpp>
#include <OpenSpaceToolkitMathematicsPy/Geometry/2D/Object/Segment.cpp>

inline void OpenSpaceToolkitMathematicsPy_Geometry_2D_Object(pybind11::module &aModule)
{
using namespace pybind11;

using ostk::mathematics::geometry::d2::Object;

// noncopyable class with Boost, removed in Pybind11
class_<Object> ob(aModule, "Object");

ob
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
# Apache License 2.0

import pytest

from ostk.mathematics.curve_fitting import Interpolator
from ostk.mathematics.curve_fitting.interpolator import BarycentricRational


@pytest.fixture
def interpolator() -> BarycentricRational:
return BarycentricRational(
x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0]
)


class TestBarycentricRational:
def test_default_constructor(self):
BarycentricRational(
x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0]
)

def test_evaluate(self):
spline = BarycentricRational(
x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0]
)

assert spline.evaluate(0.0) == 0.0
assert spline.evaluate(1.0) == 3.0
assert spline.evaluate(2.0) == 6.0
assert spline.evaluate(4.0) == 9.0
assert spline.evaluate(5.0) == 17.0
assert spline.evaluate(6.0) == 5.0
def test_constructors(self, interpolator: BarycentricRational):
assert interpolator is not None
assert isinstance(interpolator, Interpolator)
assert isinstance(interpolator, BarycentricRational)

def test_evaluate(self, interpolator: BarycentricRational):

assert interpolator.evaluate(0.0) == 0.0
assert interpolator.evaluate(1.0) == 3.0
assert interpolator.evaluate(2.0) == 6.0
assert interpolator.evaluate(4.0) == 9.0
assert interpolator.evaluate(5.0) == 17.0
assert interpolator.evaluate(6.0) == 5.0

assert (
spline.evaluate(x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0])
interpolator.evaluate(x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0])
== [0.0, 3.0, 6.0, 9.0, 17.0, 5.0]
).all()
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@

import numpy as np

from ostk.mathematics.curve_fitting import Interpolator
from ostk.mathematics.curve_fitting.interpolator import CubicSpline


@pytest.fixture
def interpolator() -> CubicSpline:
return CubicSpline(
x=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0]
)


class TestCubicSpline:
def test_default_constructor(self):
CubicSpline(x=[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0])
def test_constructors(self, interpolator: CubicSpline):
assert interpolator is not None
assert isinstance(interpolator, Interpolator)
assert isinstance(interpolator, CubicSpline)

def test_default_constructor_2(self):
CubicSpline(y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0], x_0=0.0, h=1.0)
Expand All @@ -27,9 +37,9 @@ def test_evaluate(self):
-4.671290719512360170e06,
-4.673730319758670405e06,
]
spline = CubicSpline(y, 0.0, 10.0)
interpolator = CubicSpline(y, 0.0, 10.0)

for i in range(10):
assert pytest.approx(spline.evaluate(i * 10.0)) == y[i]
assert pytest.approx(interpolator.evaluate(i * 10.0)) == y[i]

assert pytest.approx(spline.evaluate(np.linspace(0.0, 90.0, 10))) == y
assert pytest.approx(interpolator.evaluate(np.linspace(0.0, 90.0, 10))) == y
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# Apache License 2.0

import pytest

from ostk.mathematics.curve_fitting import Interpolator


@pytest.fixture
def interpolation_type() -> Interpolator.Type:
return Interpolator.Type.Linear


@pytest.fixture
def interpolator(interpolation_type: Interpolator.Type) -> Interpolator:
class MyInterpolator(Interpolator):
def evaluate(self, x: list[float]) -> list[float]:
return x

def evaluate(self, x: float) -> float:
return x

return MyInterpolator(interpolation_type=interpolation_type)


class TestInterpolator:
def test_subclass(self, interpolator: Interpolator):
assert interpolator is not None
assert isinstance(interpolator, Interpolator)

def test_get_interpolation_type(
self,
interpolator: Interpolator,
interpolation_type: Interpolator.Type,
):
assert interpolator.get_interpolation_type() == interpolation_type

@pytest.mark.parametrize(
"parametrized_interpolation_type, x, y",
[
(
Interpolator.Type.BarycentricRational,
[0.0, 1.0, 2.0, 4.0, 5.0, 6.0],
[0.0, 3.0, 6.0, 9.0, 17.0, 5.0],
),
(
Interpolator.Type.CubicSpline,
[0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
[0.0, 3.0, 6.0, 9.0, 17.0, 5.0],
),
(
Interpolator.Type.Linear,
[0.0, 1.0, 2.0, 4.0, 5.0, 6.0],
[0.0, 3.0, 6.0, 9.0, 17.0, 5.0],
),
],
)
def test_generate_interpolators(
self,
parametrized_interpolation_type: Interpolator.Type,
x: list[float],
y: list[float],
):
interpolator: Interpolator = Interpolator.generate_interpolator(
interpolation_type=parametrized_interpolation_type,
x=x,
y=y,
)

assert interpolator is not None
assert isinstance(interpolator, Interpolator)
assert interpolator.get_interpolation_type() == parametrized_interpolation_type
15 changes: 13 additions & 2 deletions bindings/python/test/curve_fitting/interpolator/test_linear.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# Apache License 2.0

import pytest

from ostk.mathematics.curve_fitting import Interpolator
from ostk.mathematics.curve_fitting.interpolator import Linear


@pytest.fixture
def interpolator() -> Linear:
return Linear(x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0])


class TestLinear:
def test_default_constructor(self):
Linear(x=[0.0, 1.0, 2.0, 4.0, 5.0, 6.0], y=[0.0, 3.0, 6.0, 9.0, 17.0, 5.0])
def test_constructors(self, interpolator: Linear):
assert interpolator is not None
assert isinstance(interpolator, Interpolator)
assert isinstance(interpolator, Linear)


def test_evaluate(self):
interpolator = Linear(
Expand Down
Loading

0 comments on commit 9830b0b

Please sign in to comment.