Skip to content

Commit

Permalink
feat: return conditional solution for segment and sequences (#233)
Browse files Browse the repository at this point in the history
* feat: return conditional solution from segments/sequences

* feat: add tests

* feat: update propagator test

* feat: address feedback
  • Loading branch information
vishwa2710 authored Oct 9, 2023
1 parent 4465595 commit 92cbe01
Show file tree
Hide file tree
Showing 15 changed files with 110 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,10 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Propagator(pybind11::modu
.def("add_dynamics", &Propagator::addDynamics, arg("dynamics"))
.def("clear_dynamics", &Propagator::clearDynamics)

.def("calculate_state_at", &Propagator::calculateStateAt, arg("state"), arg("instant"))
.def(
"calculate_state_at",
overload_cast<const State&, const Instant&>(&Propagator::calculateStateAt, const_),
arg("state"),
arg("instant")
)
.def(
"calculate_state_at",
overload_cast<const State&, const Instant&, const EventCondition&>(&Propagator::calculateStateAt, const_),
"calculate_state_to_condition",
&Propagator::calculateStateToCondition,
arg("state"),
arg("instant"),
arg("event_condition")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Segment(pybind11::module&
.def_readonly("name", &Segment::Solution::name)
.def_readonly("dynamics", &Segment::Solution::dynamics)
.def_readonly("states", &Segment::Solution::states)
.def_readonly("condition_is_satisfied", &Segment::Solution::conditionIsSatisfied)

;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Sequence(pybind11::module
class_<Sequence::Solution>(aModule, "SequenceSolution")

.def_readonly("segment_solutions", &Sequence::Solution::segmentSolutions)
.def_readonly("execution_is_complete", &Sequence::Solution::executionIsComplete)

.def("get_states", &Sequence::Solution::getStates)

;
Expand Down
17 changes: 12 additions & 5 deletions bindings/python/test/trajectory/test_propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_calculate_state_at(self, propagator: Propagator, state: State):
)
assert propagator_state.get_instant() == instant

def test_calculate_state_at(
def test_calculate_state_to_condition(
self,
conditional_numerical_solver: NumericalSolver,
dynamics: list[Dynamics],
Expand All @@ -310,10 +310,15 @@ def test_calculate_state_at(

instant: Instant = Instant.date_time(DateTime(2018, 1, 1, 0, 10, 0), Scale.UTC)

propagator_state = propagator.calculate_state_at(state, instant, event_condition)
solution = propagator.calculate_state_to_condition(
state=state,
instant=instant,
event_condition=event_condition,
)

assert solution.condition_is_satisfied
assert pytest.approx(42.0, abs=1e-3) == float(
(propagator_state.get_instant() - state.get_instant()).in_seconds()
(solution.state.get_instant() - state.get_instant()).in_seconds()
)

def test_calculate_states_at(self, propagator: Propagator, state: State):
Expand All @@ -336,7 +341,8 @@ def test_calculate_states_at_with_drag(
state_low_altitude: State,
):
propagator: Propagator = Propagator(
numerical_solver, dynamics + [atmospheric_drag]
numerical_solver,
dynamics + [atmospheric_drag],
)

instant_array = [
Expand All @@ -356,7 +362,8 @@ def test_calculate_states_at_with_thrust(
state: State,
):
propagator: Propagator = Propagator(
numerical_solver, dynamics + [constant_thrust]
numerical_solver,
dynamics + [constant_thrust],
)

instant_array = [
Expand Down
1 change: 1 addition & 0 deletions bindings/python/test/trajectory/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,4 @@ def test_solve(
== 0.0
)
assert len(solution.states) > 0
assert solution.condition_is_satisfied
1 change: 1 addition & 0 deletions bindings/python/test/trajectory/test_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,4 @@ def test_solve(self, state: State, sequence: Sequence, segments: list[Segment]):
assert len(solution.segment_solutions) == len(segments)

assert solution.get_states() is not None
assert solution.execution_is_complete
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,17 @@ class Propagator

/// @brief Calculate the state subject to an Event Condition, given initial state and maximum end time
/// @code
/// State state = propagator.calculateStateAt(aState, anInstant, anEventCondition);
/// NumericalSolver::ConditionSolution state = propagator.calculateStateAt(aState, anInstant,
/// anEventCondition);
/// @endcode
/// @param [in] aState An initial state
/// @param [in] anInstant An instant
/// @param [in] anEventCondition An event condition
/// @return State
/// @return NumericalSolver::ConditionSolution

State calculateStateAt(const State& aState, const Instant& anInstant, const EventCondition& anEventCondition) const;
NumericalSolver::ConditionSolution calculateStateToCondition(
const State& aState, const Instant& anInstant, const EventCondition& anEventCondition
) const;

/// @brief Calculate the states at an array of instants, given an initial state
/// @brief Can only be used with sorted instants array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,17 @@ class Segment
struct Solution
{
public:
Solution(const String& aName, const Array<Shared<Dynamics>>& aDynamicsArray, const Array<State>& aStates);
Solution(
const String& aName,
const Array<Shared<Dynamics>>& aDynamicsArray,
const Array<State>& aStates,
const bool& aConditionIsSatisfied
);

String name; /// Name of the segment.
Array<Shared<Dynamics>> dynamics; /// List of dynamics used.
Array<State> states; /// Array of states for the segment.
bool conditionIsSatisfied; /// True if the event condition is satisfied.
};

/// @brief Output stream operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Sequence
struct Solution
{
Array<Segment::Solution> segmentSolutions;
bool executionIsComplete;

Array<State> getStates() const;
};
Expand Down
9 changes: 2 additions & 7 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ State Propagator::calculateStateAt(const State& aState, const Instant& anInstant
);
}

State Propagator::calculateStateAt(
NumericalSolver::ConditionSolution Propagator::calculateStateToCondition(
const State& aState, const Instant& anInstant, const EventCondition& anEventCondition
) const
{
Expand All @@ -182,12 +182,7 @@ State Propagator::calculateStateAt(
anEventCondition
);

if (!conditionSolution.conditionIsSatisfied)
{
throw ostk::core::error::RuntimeError("Condition not satisfied.");
}

return conditionSolution.state;
return conditionSolution;
}

Array<State> Propagator::calculateStatesAt(const State& aState, const Array<Instant>& anInstantArray) const
Expand Down
14 changes: 9 additions & 5 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ using ostk::physics::time::Duration;
using ostk::astro::trajectory::Propagator;

Segment::Solution::Solution(
const String& aName, const Array<Shared<Dynamics>>& aDynamicsArray, const Array<State>& aStates
const String& aName,
const Array<Shared<Dynamics>>& aDynamicsArray,
const Array<State>& aStates,
const bool& aConditionIsSatisfied
)
: name(aName),
dynamics(aDynamicsArray),
states(aStates)
states(aStates),
conditionIsSatisfied(aConditionIsSatisfied)
{
}

Expand Down Expand Up @@ -108,14 +112,14 @@ Segment::Solution Segment::solve(const State& aState, const Duration& maximumPro

const Instant startInstant = aState.getInstant();

// TBI: Handle the case where the condition is not met
const State finalState =
propagator.calculateStateAt(aState, startInstant + maximumPropagationDuration, *eventCondition_);
const NumericalSolver::ConditionSolution conditionSolution =
propagator.calculateStateToCondition(aState, startInstant + maximumPropagationDuration, *eventCondition_);

return {
name_,
dynamics_,
propagator.accessNumericalSolver().accessObservedStates(),
conditionSolution.conditionIsSatisfied,
};
}

Expand Down
7 changes: 6 additions & 1 deletion src/OpenSpaceToolkit/Astrodynamics/Trajectory/Sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,16 @@ Sequence::Solution Sequence::solve(const State& aState) const

segmentSolutions.add(segmentSolution);

if (!segmentSolution.conditionIsSatisfied)
{
return {segmentSolutions, false};
}

initialState = segmentSolution.states.accessLast();
}
}

return {segmentSolutions};
return {segmentSolutions, true};
}

void Sequence::print(std::ostream& anOutputStream, bool displayDecorator) const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,12 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Models_Propagator, Calcul

const Propagator propagator = {defaultRKD5_, defaultDynamics_};

const State endState = propagator.calculateStateAt(state, endInstant, condition);
const NumericalSolver::ConditionSolution conditionSolution =
propagator.calculateStateToCondition(state, endInstant, condition);

const State endState = conditionSolution.state;

EXPECT_TRUE(conditionSolution.conditionIsSatisfied);
EXPECT_TRUE(endState.accessInstant() < endInstant);
EXPECT_LT((endState.accessInstant() - condition.getInstant()).inSeconds(), 1e-7);

Expand All @@ -432,7 +436,7 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Models_Propagator, Calcul
state.accessInstant() + Duration::Seconds(7000.0),
};

EXPECT_ANY_THROW(propagator.calculateStateAt(state, endInstant, failureCondition));
EXPECT_FALSE(propagator.calculateStateToCondition(state, endInstant, failureCondition).conditionIsSatisfied);
}
}

Expand Down
23 changes: 22 additions & 1 deletion test/OpenSpaceToolkit/Astrodynamics/Trajectory/Segment.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <OpenSpaceToolkit/Astrodynamics/Dynamics/CentralBodyGravity.hpp>
#include <OpenSpaceToolkit/Astrodynamics/Dynamics/PositionDerivative.hpp>
#include <OpenSpaceToolkit/Astrodynamics/Dynamics/Thruster/ConstantThrust.hpp>
#include <OpenSpaceToolkit/Astrodynamics/EventCondition/COECondition.hpp>
#include <OpenSpaceToolkit/Astrodynamics/EventCondition/InstantCondition.hpp>
#include <OpenSpaceToolkit/Astrodynamics/Trajectory/Segment.hpp>
#include <OpenSpaceToolkit/Astrodynamics/Trajectory/State/NumericalSolver.hpp>
Expand All @@ -30,6 +31,7 @@ using ostk::physics::env::obj::celest::Earth;
using ostk::physics::coord::Frame;
using ostk::physics::coord::Position;
using ostk::physics::coord::Velocity;
using EarthGravitationalModel = ostk::physics::environment::gravitational::Earth;

using ostk::astro::trajectory::state::NumericalSolver;
using ostk::astro::Dynamics;
Expand All @@ -41,6 +43,8 @@ using ostk::astro::trajectory::LocalOrbitalFrameFactory;
using ostk::astro::dynamics::CentralBodyGravity;
using ostk::astro::dynamics::PositionDerivative;
using ostk::astro::eventcondition::InstantCondition;
using ostk::astro::eventcondition::COECondition;
using ostk::astro::eventcondition::RealCondition;
using ostk::astro::trajectory::State;

class OpenSpaceToolkit_Astrodynamics_Trajectory_TrajectorySegment : public ::testing::Test
Expand Down Expand Up @@ -170,13 +174,30 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_TrajectorySegment, StreamOperat
TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_TrajectorySegment, Solve)
{
{
Segment::Solution solution = defaultCoastSegment_.solve(defaultState_);
const Segment::Solution solution = defaultCoastSegment_.solve(defaultState_);

EXPECT_LT(
(solution.states.accessLast().getInstant() - defaultInstantCondition_->getInstant()).inSeconds(), 1e-7
);
EXPECT_TRUE(solution.states.getSize() > 0);
}

{
const Shared<RealCondition> eventCondition = std::make_shared<RealCondition>(COECondition::Eccentricity(
RealCondition::Criterion::AnyCrossing,
Frame::GCRF(),
0.5,
EarthGravitationalModel::EGM2008.gravitationalParameter_
));

const Segment segment =
Segment::Coast("SMA condition", eventCondition, defaultDynamics_, defaultNumericalSolver_);

const Segment::Solution solution = segment.solve(defaultState_, Duration::Minutes(1.0));

EXPECT_TRUE(solution.states.getSize() > 0);
EXPECT_FALSE(solution.conditionIsSatisfied);
}
}

TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_TrajectorySegment, Print)
Expand Down
44 changes: 31 additions & 13 deletions test/OpenSpaceToolkit/Astrodynamics/Trajectory/Sequence.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,24 +239,42 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Sequence, AddManeuverSegment)

TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Sequence, Solve)
{
Sequence::Solution solution = defaultSequence_.solve(defaultState_);
{
const Sequence::Solution solution = defaultSequence_.solve(defaultState_);

EXPECT_TRUE(
solution.segmentSolutions.getSize() == defaultSequence_.getSegments().getSize() * defaultRepetitionCount_
);
EXPECT_TRUE(
solution.segmentSolutions.getSize() == defaultSequence_.getSegments().getSize() * defaultRepetitionCount_
);

Size statesSize = 0;
for (const Segment::Solution& segmentSolution : solution.segmentSolutions)
{
EXPECT_TRUE(segmentSolution.states.getSize() > 0);
Size statesSize = 0;
for (const Segment::Solution& segmentSolution : solution.segmentSolutions)
{
EXPECT_TRUE(segmentSolution.states.getSize() > 0);

const Real targetAngle = defaultCondition_->getEvaluator()(segmentSolution.states.accessLast());
EXPECT_NEAR(targetAngle, defaultCondition_->getTargetAngle().inRadians(0.0, Real::TwoPi()), 1e-6);
const Real targetAngle = defaultCondition_->getEvaluator()(segmentSolution.states.accessLast());
EXPECT_NEAR(targetAngle, defaultCondition_->getTargetAngle().inRadians(0.0, Real::TwoPi()), 1e-6);

statesSize += segmentSolution.states.getSize();
statesSize += segmentSolution.states.getSize();
}

EXPECT_TRUE(solution.getStates().getSize() == statesSize);
}

EXPECT_TRUE(solution.getStates().getSize() == statesSize);
{
const Sequence sequence = {
defaultSegments_,
defaultRepetitionCount_,
defaultNumericalSolver_,
defaultDynamics_,
Duration::Seconds(1.0),
};

const Sequence::Solution solution = sequence.solve(defaultState_);

EXPECT_FALSE(solution.executionIsComplete);
EXPECT_TRUE(solution.segmentSolutions.getSize() == 1);
EXPECT_FALSE(solution.segmentSolutions[0].conditionIsSatisfied);
}
}

TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Sequence, Solve_2)
Expand Down Expand Up @@ -341,7 +359,7 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Sequence, Solve_2)
coordinatesBrokerSPtr,
};

Sequence::Solution solution = sequence.solve(state);
const Sequence::Solution solution = sequence.solve(state);

EXPECT_TRUE(solution.segmentSolutions.getSize() == 2 * defaultRepetitionCount_);
}
Expand Down

0 comments on commit 92cbe01

Please sign in to comment.