From 6f0159eda625df872988191467ebdfe62cfccf71 Mon Sep 17 00:00:00 2001 From: Antoine Paletta <98616558+apaletta3@users.noreply.github.com> Date: Wed, 14 Feb 2024 21:32:28 +0100 Subject: [PATCH] refactor: tabulated trajectory to use Interpolator::Type (#349) * refactor: tabulated trajectory to use Interpolator::Type * fix: apply vishwa's suggestions Co-authored-by: Vishwa Shah --------- Co-authored-by: Vishwa Shah --- .../Trajectory/Orbit/Model/Tabulated.cpp | 27 ++++-------- .../trajectory/orbit/models/test_tabulated.py | 42 +++++++++---------- .../Trajectory/Model/Tabulated.hpp | 26 +++--------- .../Trajectory/Orbit/Model/Tabulated.hpp | 6 ++- .../Trajectory/Model/Tabulated.cpp | 34 ++++----------- .../Astrodynamics/Trajectory/Orbit.cpp | 7 ++-- .../Trajectory/Orbit/Model/Tabulated.cpp | 2 +- .../Trajectory/Orbit/Model/Tabulated.test.cpp | 36 +++++++++------- 8 files changed, 70 insertions(+), 110 deletions(-) diff --git a/bindings/python/src/OpenSpaceToolkitAstrodynamicsPy/Trajectory/Orbit/Model/Tabulated.cpp b/bindings/python/src/OpenSpaceToolkitAstrodynamicsPy/Trajectory/Orbit/Model/Tabulated.cpp index 4b06afe5a..a533caeb9 100644 --- a/bindings/python/src/OpenSpaceToolkitAstrodynamicsPy/Trajectory/Orbit/Model/Tabulated.cpp +++ b/bindings/python/src/OpenSpaceToolkitAstrodynamicsPy/Trajectory/Orbit/Model/Tabulated.cpp @@ -7,49 +7,36 @@ using namespace pybind11; using ostk::core::container::Array; using ostk::core::type::Integer; +using ostk::mathematics::curvefitting::Interpolator; + using ostk::astrodynamics::trajectory::State; using ostk::astrodynamics::trajectory::orbit::model::Tabulated; inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Orbit_Model_Tabulated(pybind11::module& aModule) { - class_ tabulated_class( + class_( aModule, "Tabulated", R"doc( Tabulated orbit model. )doc" - ); - - enum_( - tabulated_class, - "InterpolationType", - R"doc( - The Interpolation Type. - )doc" ) - .value("Linear", Tabulated::InterpolationType::Linear, "Linear") - .value("CubicSpline", Tabulated::InterpolationType::CubicSpline, "Cubic Spline") - .value("BarycentricRational", Tabulated::InterpolationType::BarycentricRational, "Barycentric Rational") - - ; - - tabulated_class .def( - init, Integer, Tabulated::InterpolationType>(), + init, Integer, Interpolator::Type>(), R"doc( Constructor. Args: states (list[State]): The states. initial_revolution_number (int): The initial revolution number. - interpolation_type (Tabulated.InterpolationType, optional): The interpolation type. + interpolation_type (Interpolator.Type, optional): The interpolation type. )doc", arg("states"), arg("initial_revolution_number"), - arg("interpolation_type") = DEFAULT_TABULATED_INTERPOLATION_TYPE + arg("interpolation_type") = DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE ) .def(self == self) @@ -115,7 +102,7 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Orbit_Model_Tabulated(pyb Get the interpolation type of the `Tabulated` model. Returns: - Tabulated.InterpolationType: The interpolation type. + Interpolator.Type: The interpolation type. )doc" ) diff --git a/bindings/python/test/trajectory/orbit/models/test_tabulated.py b/bindings/python/test/trajectory/orbit/models/test_tabulated.py index c12240fb1..4837eb2d2 100644 --- a/bindings/python/test/trajectory/orbit/models/test_tabulated.py +++ b/bindings/python/test/trajectory/orbit/models/test_tabulated.py @@ -4,6 +4,8 @@ import numpy as np +from ostk.mathematics.curve_fitting import Interpolator + from ostk.physics.time import Instant from ostk.physics.time import DateTime from ostk.physics.time import Scale @@ -230,13 +232,13 @@ class TestTabulated: @pytest.mark.parametrize( "interpolation_type", ( - (Tabulated.InterpolationType.Linear), - (Tabulated.InterpolationType.CubicSpline), - (Tabulated.InterpolationType.BarycentricRational), + (Interpolator.Type.Linear), + (Interpolator.Type.CubicSpline), + (Interpolator.Type.BarycentricRational), ), ) def test_constructor( - self, test_states: list[State], interpolation_type: Tabulated.InterpolationType + self, test_states: list[State], interpolation_type: Interpolator.Type ): assert ( Tabulated( @@ -255,7 +257,7 @@ def test_constructor_orbit_tabulated_sucess( tabulated = Tabulated( states=test_states, initial_revolution_number=1, - interpolation_type=Tabulated.InterpolationType.CubicSpline, + interpolation_type=Interpolator.Type.CubicSpline, ) orbit: Orbit = Orbit(tabulated, earth) @@ -267,26 +269,24 @@ def test_get_interpolation_type(self, test_states: list[State]): tabulated = Tabulated( states=test_states, initial_revolution_number=1, - interpolation_type=Tabulated.InterpolationType.CubicSpline, + interpolation_type=Interpolator.Type.CubicSpline, ) - assert ( - tabulated.get_interpolation_type() == Tabulated.InterpolationType.CubicSpline - ) + assert tabulated.get_interpolation_type() == Interpolator.Type.CubicSpline @pytest.mark.parametrize( "interpolation_type,error_tolerance", ( - (Tabulated.InterpolationType.Linear, 420.0), - (Tabulated.InterpolationType.CubicSpline, 5e-3), - (Tabulated.InterpolationType.BarycentricRational, 5e-2), + (Interpolator.Type.Linear, 420.0), + (Interpolator.Type.CubicSpline, 5e-3), + (Interpolator.Type.BarycentricRational, 5e-2), ), ) def test_calculate_state_at_success( self, test_states: list[State], reference_states: list[State], - interpolation_type: Tabulated.InterpolationType, + interpolation_type: Interpolator.Type, error_tolerance: float, ): tabulated = Tabulated( @@ -314,16 +314,16 @@ def test_calculate_state_at_success( @pytest.mark.parametrize( "interpolation_type,error_tolerance", ( - (Tabulated.InterpolationType.Linear, 420.0), - (Tabulated.InterpolationType.CubicSpline, 5e-3), - (Tabulated.InterpolationType.BarycentricRational, 5e-2), + (Interpolator.Type.Linear, 420.0), + (Interpolator.Type.CubicSpline, 5e-3), + (Interpolator.Type.BarycentricRational, 5e-2), ), ) def test_calculate_states_at_success( self, test_states: list[State], reference_states: list[State], - interpolation_type: Tabulated.InterpolationType, + interpolation_type: Interpolator.Type, error_tolerance: float, ): tabulated = Tabulated( @@ -355,13 +355,13 @@ def test_calculate_states_at_success( @pytest.mark.parametrize( "interpolation_type", ( - (Tabulated.InterpolationType.Linear), - (Tabulated.InterpolationType.CubicSpline), - (Tabulated.InterpolationType.BarycentricRational), + (Interpolator.Type.Linear), + (Interpolator.Type.CubicSpline), + (Interpolator.Type.BarycentricRational), ), ) def test_calculate_state_at_failure( - self, test_states: list[State], interpolation_type: Tabulated.InterpolationType + self, test_states: list[State], interpolation_type: Interpolator.Type ): tabulated = Tabulated( states=test_states, diff --git a/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.hpp b/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.hpp index 0fd182beb..6bc8be574 100644 --- a/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.hpp +++ b/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.hpp @@ -10,9 +10,7 @@ #include #include -#include -#include -#include +#include #include #include @@ -37,7 +35,7 @@ using ostk::core::type::Index; using ostk::core::type::Shared; using ostk::core::type::Size; -using ostk::mathematics::curvefitting::interpolator::Interpolator; +using ostk::mathematics::curvefitting::Interpolator; using ostk::mathematics::object::MatrixXd; using ostk::mathematics::object::VectorXd; @@ -48,7 +46,7 @@ using ostk::physics::time::Scale; using ostk::astrodynamics::trajectory::Model; using ostk::astrodynamics::trajectory::State; -#define DEFAULT_TABULATED_INTERPOLATION_TYPE Tabulated::InterpolationType::Linear +#define DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE Interpolator::Type::Linear /// @brief Tabulated trajectory model /// @@ -58,18 +56,9 @@ using ostk::astrodynamics::trajectory::State; class Tabulated : public virtual Model { public: - enum class InterpolationType - { - - Linear, - BarycentricRational, - CubicSpline - - }; - Tabulated( const Array& aStateArray, - const InterpolationType& anInterpolationType = DEFAULT_TABULATED_INTERPOLATION_TYPE + const Interpolator::Type& anInterpolationType = DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE ); virtual Tabulated* clone() const override; @@ -84,7 +73,7 @@ class Tabulated : public virtual Model Interval getInterval() const; - InterpolationType getInterpolationType() const; + Interpolator::Type getInterpolationType() const; State getFirstState() const; @@ -106,10 +95,7 @@ class Tabulated : public virtual Model private: State firstState_ = State::Undefined(); State lastState_ = State::Undefined(); - - InterpolationType interpolationType_; - - Array> interpolators_ = Array>::Empty(); + Array> interpolators_; }; } // namespace model diff --git a/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.hpp b/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.hpp index 4b8e2dafd..608db6f4b 100644 --- a/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.hpp +++ b/include/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.hpp @@ -6,6 +6,8 @@ #include #include +#include + #include #include @@ -27,6 +29,8 @@ namespace model using ostk::core::container::Array; using ostk::core::type::Integer; +using ostk::mathematics::curvefitting::Interpolator; + using ostk::physics::time::Instant; using ostk::astrodynamics::trajectory::State; @@ -37,7 +41,7 @@ class Tabulated : public virtual trajectory::orbit::Model, public trajectory::mo Tabulated( const Array& aStateArray, const Integer& anInitialRevolutionNumber, - const InterpolationType& aType = DEFAULT_TABULATED_INTERPOLATION_TYPE + const Interpolator::Type& aType = DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE ); virtual Tabulated* clone() const override; diff --git a/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.cpp b/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.cpp index 37801553b..7bfcc3ebb 100644 --- a/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.cpp +++ b/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.cpp @@ -15,14 +15,9 @@ namespace trajectory namespace model { -Tabulated::Tabulated(const Array& aStateArray, const InterpolationType& anInterpolationType) - : Model(), - interpolationType_(anInterpolationType) +Tabulated::Tabulated(const Array& aStateArray, const Interpolator::Type& anInterpolationType) + : Model() { - using ostk::mathematics::curvefitting::interpolator::BarycentricRational; - using ostk::mathematics::curvefitting::interpolator::CubicSpline; - using ostk::mathematics::curvefitting::interpolator::Linear; - if (aStateArray.getSize() < 2) { return; @@ -56,23 +51,7 @@ Tabulated::Tabulated(const Array& aStateArray, const InterpolationType& a for (Index i = 0; i < Size(coordinates.cols()); ++i) { - if (interpolationType_ == Tabulated::InterpolationType::CubicSpline) - { - interpolators_.add(std::make_shared(CubicSpline(timestamps, coordinates.col(i)))); - } - else if (interpolationType_ == Tabulated::InterpolationType::BarycentricRational) - { - interpolators_.add(std::make_shared(BarycentricRational(timestamps, coordinates.col(i)) - )); - } - else if (interpolationType_ == Tabulated::InterpolationType::Linear) - { - interpolators_.add(std::make_shared(Linear(timestamps, coordinates.col(i)))); - } - else - { - throw ostk::core::error::runtime::Wrong("InterpolationType"); - } + interpolators_.add(Interpolator::GenerateInterpolator(anInterpolationType, timestamps, coordinates.col(i))); } } @@ -88,7 +67,7 @@ bool Tabulated::operator==(const Tabulated& aTabulatedModel) const return false; } - return interpolationType_ == aTabulatedModel.getInterpolationType() && + return this->getInterpolationType() == aTabulatedModel.getInterpolationType() && firstState_ == aTabulatedModel.getFirstState() && lastState_ == aTabulatedModel.getLastState(); } @@ -119,14 +98,15 @@ Interval Tabulated::getInterval() const return Interval::Closed(firstState_.accessInstant(), lastState_.accessInstant()); } -Tabulated::InterpolationType Tabulated::getInterpolationType() const +Interpolator::Type Tabulated::getInterpolationType() const { if (!this->isDefined()) { throw ostk::core::error::runtime::Undefined("Tabulated"); } - return interpolationType_; + // Since all interpolators are of the same type, we can just return the type of the first one. + return interpolators_[0]->getInterpolationType(); } State Tabulated::getFirstState() const diff --git a/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp b/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp index 913a4b50f..0c7094088 100644 --- a/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp +++ b/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp @@ -1,10 +1,9 @@ /// Apache License 2.0 -#include - #include #include +#include #include #include @@ -36,6 +35,7 @@ using ostk::core::type::Uint8; using ostk::core::type::Real; using ostk::core::type::Index; +using ostk::mathematics::curvefitting::Interpolator; using ostk::mathematics::object::Vector3d; using ostk::physics::time::Duration; @@ -1053,8 +1053,7 @@ Array> Orbit::ComputePasses(const Array& aStateArray, c } } - const model::Tabulated tabulated = - model::Tabulated(aStateArray, model::Tabulated::InterpolationType::BarycentricRational); + const model::Tabulated tabulated = model::Tabulated(aStateArray, Interpolator::Type::BarycentricRational); const Instant& epoch = aStateArray.accessFirst().accessInstant(); diff --git a/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.cpp b/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.cpp index 56ada33fb..880631772 100644 --- a/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.cpp +++ b/src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.cpp @@ -19,7 +19,7 @@ namespace model Tabulated::Tabulated( const Array& aStateArray, const Integer& anInitialRevolutionNumber, - const InterpolationType& anInterpolationType + const Interpolator::Type& anInterpolationType ) : trajectory::orbit::Model(), trajectory::model::Tabulated(aStateArray, anInterpolationType), diff --git a/test/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.test.cpp b/test/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.test.cpp index 7c82de8ef..6d8bcc22e 100644 --- a/test/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.test.cpp +++ b/test/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Model/Tabulated.test.cpp @@ -3,6 +3,9 @@ #include #include +#include +#include + #include #include @@ -22,6 +25,7 @@ using ostk::core::container::Tuple; using ostk::core::filesystem::Path; using ostk::core::filesystem::File; +using ostk::mathematics::curvefitting::Interpolator; using ostk::mathematics::object::VectorXd; using ostk::physics::Environment; @@ -101,7 +105,7 @@ class OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated : public : TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, Constructor) { - const Tabulated tabulated(states_, 0, Tabulated::InterpolationType::Linear); + const Tabulated tabulated(states_, 0, Interpolator::Type::Linear); Environment environment = Environment::Default(); @@ -112,7 +116,7 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, GetInter { using ostk::physics::time::Interval; - const Tabulated tabulated(states_, 0, Tabulated::InterpolationType::Linear); + const Tabulated tabulated(states_, 0, Interpolator::Type::Linear); EXPECT_TRUE(tabulated.getInterval().isDefined()); EXPECT_TRUE( @@ -123,16 +127,16 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, GetInter TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, EqualToOperator) { - const Tabulated tabulated(states_, 0, Tabulated::InterpolationType::Linear); - const Tabulated anotherTabulated(states_, 0, Tabulated::InterpolationType::Linear); + const Tabulated tabulated(states_, 0, Interpolator::Type::Linear); + const Tabulated anotherTabulated(states_, 0, Interpolator::Type::Linear); EXPECT_TRUE(tabulated == anotherTabulated); } TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, NotEqualToOperator) { - const Tabulated tabulated(states_, 0, Tabulated::InterpolationType::CubicSpline); - const Tabulated anotherTabulated(states_, 0, Tabulated::InterpolationType::Linear); + const Tabulated tabulated(states_, 0, Interpolator::Type::CubicSpline); + const Tabulated anotherTabulated(states_, 0, Interpolator::Type::Linear); EXPECT_TRUE(tabulated != anotherTabulated); } @@ -141,15 +145,15 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, Calculat { loadData(); - const Array> testCases = { - {Tabulated::InterpolationType::Linear, 420.0}, - {Tabulated::InterpolationType::BarycentricRational, 5e-2}, - {Tabulated::InterpolationType::CubicSpline, 5e-3}, + const Array> testCases = { + {Interpolator::Type::Linear, 420.0}, + {Interpolator::Type::BarycentricRational, 5e-2}, + {Interpolator::Type::CubicSpline, 5e-3}, }; for (const auto& testCase : testCases) { - Tabulated::InterpolationType interpolationType = std::get<0>(testCase); + Interpolator::Type interpolationType = std::get<0>(testCase); Real tolerance = std::get<1>(testCase); const Tabulated tabulated(states_, 0, interpolationType); @@ -176,15 +180,15 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Model_Tabulated, Calculat { loadData(); - const Array> testCases = { - {Tabulated::InterpolationType::Linear, 420.0}, - {Tabulated::InterpolationType::BarycentricRational, 5e-2}, - {Tabulated::InterpolationType::CubicSpline, 5e-3}, + const Array> testCases = { + {Interpolator::Type::Linear, 420.0}, + {Interpolator::Type::BarycentricRational, 5e-2}, + {Interpolator::Type::CubicSpline, 5e-3}, }; for (const auto& testCase : testCases) { - Tabulated::InterpolationType interpolationType = std::get<0>(testCase); + Interpolator::Type interpolationType = std::get<0>(testCase); Real tolerance = std::get<1>(testCase); const Tabulated tabulated(states_, 0, interpolationType);