Skip to content

Commit

Permalink
refactor: tabulated trajectory to use Interpolator::Type (#349)
Browse files Browse the repository at this point in the history
* refactor: tabulated trajectory to use Interpolator::Type

* fix: apply vishwa's suggestions

Co-authored-by: Vishwa Shah <vishwa2710@gmail.com>

---------

Co-authored-by: Vishwa Shah <vishwa2710@gmail.com>
  • Loading branch information
apaletta3 and vishwa2710 authored Feb 14, 2024
1 parent 7412071 commit 6f0159e
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,49 +7,36 @@ using namespace pybind11;
using ostk::core::container::Array;
using ostk::core::type::Integer;

using ostk::mathematics::curvefitting::Interpolator;

using ostk::astrodynamics::trajectory::State;
using ostk::astrodynamics::trajectory::orbit::model::Tabulated;

inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Orbit_Model_Tabulated(pybind11::module& aModule)
{
class_<Tabulated, ostk::astrodynamics::trajectory::orbit::Model> tabulated_class(
class_<Tabulated, ostk::astrodynamics::trajectory::orbit::Model>(
aModule,
"Tabulated",
R"doc(
Tabulated orbit model.
)doc"
);

enum_<Tabulated::InterpolationType>(
tabulated_class,
"InterpolationType",
R"doc(
The Interpolation Type.
)doc"
)
.value("Linear", Tabulated::InterpolationType::Linear, "Linear")
.value("CubicSpline", Tabulated::InterpolationType::CubicSpline, "Cubic Spline")
.value("BarycentricRational", Tabulated::InterpolationType::BarycentricRational, "Barycentric Rational")

;

tabulated_class

.def(
init<Array<State>, Integer, Tabulated::InterpolationType>(),
init<Array<State>, Integer, Interpolator::Type>(),
R"doc(
Constructor.
Args:
states (list[State]): The states.
initial_revolution_number (int): The initial revolution number.
interpolation_type (Tabulated.InterpolationType, optional): The interpolation type.
interpolation_type (Interpolator.Type, optional): The interpolation type.
)doc",
arg("states"),
arg("initial_revolution_number"),
arg("interpolation_type") = DEFAULT_TABULATED_INTERPOLATION_TYPE
arg("interpolation_type") = DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE
)

.def(self == self)
Expand Down Expand Up @@ -115,7 +102,7 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Orbit_Model_Tabulated(pyb
Get the interpolation type of the `Tabulated` model.
Returns:
Tabulated.InterpolationType: The interpolation type.
Interpolator.Type: The interpolation type.
)doc"
)
Expand Down
42 changes: 21 additions & 21 deletions bindings/python/test/trajectory/orbit/models/test_tabulated.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import numpy as np

from ostk.mathematics.curve_fitting import Interpolator

from ostk.physics.time import Instant
from ostk.physics.time import DateTime
from ostk.physics.time import Scale
Expand Down Expand Up @@ -230,13 +232,13 @@ class TestTabulated:
@pytest.mark.parametrize(
"interpolation_type",
(
(Tabulated.InterpolationType.Linear),
(Tabulated.InterpolationType.CubicSpline),
(Tabulated.InterpolationType.BarycentricRational),
(Interpolator.Type.Linear),
(Interpolator.Type.CubicSpline),
(Interpolator.Type.BarycentricRational),
),
)
def test_constructor(
self, test_states: list[State], interpolation_type: Tabulated.InterpolationType
self, test_states: list[State], interpolation_type: Interpolator.Type
):
assert (
Tabulated(
Expand All @@ -255,7 +257,7 @@ def test_constructor_orbit_tabulated_sucess(
tabulated = Tabulated(
states=test_states,
initial_revolution_number=1,
interpolation_type=Tabulated.InterpolationType.CubicSpline,
interpolation_type=Interpolator.Type.CubicSpline,
)

orbit: Orbit = Orbit(tabulated, earth)
Expand All @@ -267,26 +269,24 @@ def test_get_interpolation_type(self, test_states: list[State]):
tabulated = Tabulated(
states=test_states,
initial_revolution_number=1,
interpolation_type=Tabulated.InterpolationType.CubicSpline,
interpolation_type=Interpolator.Type.CubicSpline,
)

assert (
tabulated.get_interpolation_type() == Tabulated.InterpolationType.CubicSpline
)
assert tabulated.get_interpolation_type() == Interpolator.Type.CubicSpline

@pytest.mark.parametrize(
"interpolation_type,error_tolerance",
(
(Tabulated.InterpolationType.Linear, 420.0),
(Tabulated.InterpolationType.CubicSpline, 5e-3),
(Tabulated.InterpolationType.BarycentricRational, 5e-2),
(Interpolator.Type.Linear, 420.0),
(Interpolator.Type.CubicSpline, 5e-3),
(Interpolator.Type.BarycentricRational, 5e-2),
),
)
def test_calculate_state_at_success(
self,
test_states: list[State],
reference_states: list[State],
interpolation_type: Tabulated.InterpolationType,
interpolation_type: Interpolator.Type,
error_tolerance: float,
):
tabulated = Tabulated(
Expand Down Expand Up @@ -314,16 +314,16 @@ def test_calculate_state_at_success(
@pytest.mark.parametrize(
"interpolation_type,error_tolerance",
(
(Tabulated.InterpolationType.Linear, 420.0),
(Tabulated.InterpolationType.CubicSpline, 5e-3),
(Tabulated.InterpolationType.BarycentricRational, 5e-2),
(Interpolator.Type.Linear, 420.0),
(Interpolator.Type.CubicSpline, 5e-3),
(Interpolator.Type.BarycentricRational, 5e-2),
),
)
def test_calculate_states_at_success(
self,
test_states: list[State],
reference_states: list[State],
interpolation_type: Tabulated.InterpolationType,
interpolation_type: Interpolator.Type,
error_tolerance: float,
):
tabulated = Tabulated(
Expand Down Expand Up @@ -355,13 +355,13 @@ def test_calculate_states_at_success(
@pytest.mark.parametrize(
"interpolation_type",
(
(Tabulated.InterpolationType.Linear),
(Tabulated.InterpolationType.CubicSpline),
(Tabulated.InterpolationType.BarycentricRational),
(Interpolator.Type.Linear),
(Interpolator.Type.CubicSpline),
(Interpolator.Type.BarycentricRational),
),
)
def test_calculate_state_at_failure(
self, test_states: list[State], interpolation_type: Tabulated.InterpolationType
self, test_states: list[State], interpolation_type: Interpolator.Type
):
tabulated = Tabulated(
states=test_states,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
#include <OpenSpaceToolkit/Core/Type/Shared.hpp>

#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator.hpp>
#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator/BarycentricRational.hpp>
#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator/CubicSpline.hpp>
#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator/Linear.hpp>
#include <OpenSpaceToolkit/Mathematics/Object/Vector.hpp>

#include <OpenSpaceToolkit/Physics/Time/Instant.hpp>
#include <OpenSpaceToolkit/Physics/Time/Interval.hpp>
Expand All @@ -37,7 +35,7 @@ using ostk::core::type::Index;
using ostk::core::type::Shared;
using ostk::core::type::Size;

using ostk::mathematics::curvefitting::interpolator::Interpolator;
using ostk::mathematics::curvefitting::Interpolator;
using ostk::mathematics::object::MatrixXd;
using ostk::mathematics::object::VectorXd;

Expand All @@ -48,7 +46,7 @@ using ostk::physics::time::Scale;
using ostk::astrodynamics::trajectory::Model;
using ostk::astrodynamics::trajectory::State;

#define DEFAULT_TABULATED_INTERPOLATION_TYPE Tabulated::InterpolationType::Linear
#define DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE Interpolator::Type::Linear

/// @brief Tabulated trajectory model
///
Expand All @@ -58,18 +56,9 @@ using ostk::astrodynamics::trajectory::State;
class Tabulated : public virtual Model
{
public:
enum class InterpolationType
{

Linear,
BarycentricRational,
CubicSpline

};

Tabulated(
const Array<State>& aStateArray,
const InterpolationType& anInterpolationType = DEFAULT_TABULATED_INTERPOLATION_TYPE
const Interpolator::Type& anInterpolationType = DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE
);

virtual Tabulated* clone() const override;
Expand All @@ -84,7 +73,7 @@ class Tabulated : public virtual Model

Interval getInterval() const;

InterpolationType getInterpolationType() const;
Interpolator::Type getInterpolationType() const;

State getFirstState() const;

Expand All @@ -106,10 +95,7 @@ class Tabulated : public virtual Model
private:
State firstState_ = State::Undefined();
State lastState_ = State::Undefined();

InterpolationType interpolationType_;

Array<Shared<Interpolator>> interpolators_ = Array<Shared<Interpolator>>::Empty();
Array<Shared<const Interpolator>> interpolators_;
};

} // namespace model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <OpenSpaceToolkit/Core/Container/Array.hpp>
#include <OpenSpaceToolkit/Core/Type/Integer.hpp>

#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator.hpp>

#include <OpenSpaceToolkit/Physics/Time/Instant.hpp>

#include <OpenSpaceToolkit/Astrodynamics/Trajectory/Model.hpp>
Expand All @@ -27,6 +29,8 @@ namespace model
using ostk::core::container::Array;
using ostk::core::type::Integer;

using ostk::mathematics::curvefitting::Interpolator;

using ostk::physics::time::Instant;

using ostk::astrodynamics::trajectory::State;
Expand All @@ -37,7 +41,7 @@ class Tabulated : public virtual trajectory::orbit::Model, public trajectory::mo
Tabulated(
const Array<State>& aStateArray,
const Integer& anInitialRevolutionNumber,
const InterpolationType& aType = DEFAULT_TABULATED_INTERPOLATION_TYPE
const Interpolator::Type& aType = DEFAULT_TABULATED_TRAJECTORY_INTERPOLATION_TYPE
);

virtual Tabulated* clone() const override;
Expand Down
34 changes: 7 additions & 27 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Model/Tabulated.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,9 @@ namespace trajectory
namespace model
{

Tabulated::Tabulated(const Array<State>& aStateArray, const InterpolationType& anInterpolationType)
: Model(),
interpolationType_(anInterpolationType)
Tabulated::Tabulated(const Array<State>& aStateArray, const Interpolator::Type& anInterpolationType)
: Model()
{
using ostk::mathematics::curvefitting::interpolator::BarycentricRational;
using ostk::mathematics::curvefitting::interpolator::CubicSpline;
using ostk::mathematics::curvefitting::interpolator::Linear;

if (aStateArray.getSize() < 2)
{
return;
Expand Down Expand Up @@ -56,23 +51,7 @@ Tabulated::Tabulated(const Array<State>& aStateArray, const InterpolationType& a

for (Index i = 0; i < Size(coordinates.cols()); ++i)
{
if (interpolationType_ == Tabulated::InterpolationType::CubicSpline)
{
interpolators_.add(std::make_shared<CubicSpline>(CubicSpline(timestamps, coordinates.col(i))));
}
else if (interpolationType_ == Tabulated::InterpolationType::BarycentricRational)
{
interpolators_.add(std::make_shared<BarycentricRational>(BarycentricRational(timestamps, coordinates.col(i))
));
}
else if (interpolationType_ == Tabulated::InterpolationType::Linear)
{
interpolators_.add(std::make_shared<Linear>(Linear(timestamps, coordinates.col(i))));
}
else
{
throw ostk::core::error::runtime::Wrong("InterpolationType");
}
interpolators_.add(Interpolator::GenerateInterpolator(anInterpolationType, timestamps, coordinates.col(i)));
}
}

Expand All @@ -88,7 +67,7 @@ bool Tabulated::operator==(const Tabulated& aTabulatedModel) const
return false;
}

return interpolationType_ == aTabulatedModel.getInterpolationType() &&
return this->getInterpolationType() == aTabulatedModel.getInterpolationType() &&
firstState_ == aTabulatedModel.getFirstState() && lastState_ == aTabulatedModel.getLastState();
}

Expand Down Expand Up @@ -119,14 +98,15 @@ Interval Tabulated::getInterval() const
return Interval::Closed(firstState_.accessInstant(), lastState_.accessInstant());
}

Tabulated::InterpolationType Tabulated::getInterpolationType() const
Interpolator::Type Tabulated::getInterpolationType() const
{
if (!this->isDefined())
{
throw ostk::core::error::runtime::Undefined("Tabulated");
}

return interpolationType_;
// Since all interpolators are of the same type, we can just return the type of the first one.
return interpolators_[0]->getInterpolationType();
}

State Tabulated::getFirstState() const
Expand Down
7 changes: 3 additions & 4 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
/// Apache License 2.0

#include <iostream>

#include <OpenSpaceToolkit/Core/Error.hpp>
#include <OpenSpaceToolkit/Core/Utility.hpp>

#include <OpenSpaceToolkit/Mathematics/CurveFitting/Interpolator.hpp>
#include <OpenSpaceToolkit/Mathematics/Geometry/3D/Transformation/Rotation/RotationMatrix.hpp>
#include <OpenSpaceToolkit/Mathematics/Geometry/3D/Transformation/Rotation/RotationVector.hpp>

Expand Down Expand Up @@ -36,6 +35,7 @@ using ostk::core::type::Uint8;
using ostk::core::type::Real;
using ostk::core::type::Index;

using ostk::mathematics::curvefitting::Interpolator;
using ostk::mathematics::object::Vector3d;

using ostk::physics::time::Duration;
Expand Down Expand Up @@ -1053,8 +1053,7 @@ Array<Pair<Index, Pass>> Orbit::ComputePasses(const Array<State>& aStateArray, c
}
}

const model::Tabulated tabulated =
model::Tabulated(aStateArray, model::Tabulated::InterpolationType::BarycentricRational);
const model::Tabulated tabulated = model::Tabulated(aStateArray, Interpolator::Type::BarycentricRational);

const Instant& epoch = aStateArray.accessFirst().accessInstant();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ namespace model
Tabulated::Tabulated(
const Array<State>& aStateArray,
const Integer& anInitialRevolutionNumber,
const InterpolationType& anInterpolationType
const Interpolator::Type& anInterpolationType
)
: trajectory::orbit::Model(),
trajectory::model::Tabulated(aStateArray, anInterpolationType),
Expand Down
Loading

0 comments on commit 6f0159e

Please sign in to comment.