Skip to content

Commit

Permalink
feat: pass back conditional solution information for segment and sequ…
Browse files Browse the repository at this point in the history
…ences
  • Loading branch information
vishwa2710 committed Oct 6, 2023
1 parent 4465595 commit 5ff075e
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Segment(pybind11::module&
.def_readonly("name", &Segment::Solution::name)
.def_readonly("dynamics", &Segment::Solution::dynamics)
.def_readonly("states", &Segment::Solution::states)
.def_readonly("condition_is_satisfied", &Segment::Solution::conditionIsSatisfied)

;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_Sequence(pybind11::module
class_<Sequence::Solution>(aModule, "SequenceSolution")

.def_readonly("segment_solutions", &Sequence::Solution::segmentSolutions)
.def_readonly("sequence_is_complete", &Sequence::Solution::sequenceIsComplete)

.def("get_states", &Sequence::Solution::getStates)

;
Expand Down
111 changes: 82 additions & 29 deletions bindings/python/test/trajectory/state/test_numerical_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,27 +67,55 @@ def custom_condition() -> RealCondition:


@pytest.fixture
def numerical_solver_default_inputs() -> (
tuple[NumericalSolver.LogType, NumericalSolver.StepperType, float, float, float]
):
log_type = NumericalSolver.LogType.NoLog
stepper_type = NumericalSolver.StepperType.RungeKuttaCashKarp54
initial_time_step = 5.0
relative_tolerance = 1.0e-15
absolute_tolerance = 1.0e-15

return (
log_type,
stepper_type,
initial_time_step,
relative_tolerance,
absolute_tolerance,
)
def log_type() -> NumericalSolver.LogType:
return NumericalSolver.LogType.NoLog


@pytest.fixture
def stepper_type() -> NumericalSolver.StepperType:
return NumericalSolver.StepperType.RungeKuttaCashKarp54


@pytest.fixture
def initial_time_step() -> float:
return 5.0


@pytest.fixture
def relative_tolerance() -> float:
return 1.0e-15


@pytest.fixture
def numerical_solver(numerical_solver_default_inputs) -> NumericalSolver:
return NumericalSolver(*numerical_solver_default_inputs)
def absolute_tolerance() -> float:
return 1.0e-15


@pytest.fixture
def state_logger() -> callable:
def log_state(state: State) -> None:
print(state.get_coordinates())

return log_state


@pytest.fixture
def numerical_solver(
log_type: NumericalSolver.LogType,
stepper_type: NumericalSolver.StepperType,
initial_time_step: float,
relative_tolerance: float,
absolute_tolerance: float,
state_logger: callable,
) -> NumericalSolver:
return NumericalSolver(
log_type=log_type,
stepper_type=stepper_type,
time_step=initial_time_step,
relative_tolerance=relative_tolerance,
absolute_tolerance=absolute_tolerance,
state_logger=state_logger,
)


@pytest.fixture
Expand All @@ -113,19 +141,13 @@ def test_comparators(self, numerical_solver: NumericalSolver):

def test_get_types(
self,
numerical_solver_default_inputs: tuple[
NumericalSolver.LogType, NumericalSolver.StepperType, float, float, float
],
log_type: NumericalSolver.LogType,
stepper_type: NumericalSolver.StepperType,
initial_time_step: float,
relative_tolerance: float,
absolute_tolerance: float,
numerical_solver: NumericalSolver,
):
(
log_type,
stepper_type,
initial_time_step,
relative_tolerance,
absolute_tolerance,
) = numerical_solver_default_inputs

assert numerical_solver.get_log_type() == log_type
assert numerical_solver.get_stepper_type() == stepper_type
assert numerical_solver.get_time_step() == initial_time_step
Expand Down Expand Up @@ -231,6 +253,29 @@ def test_integrate_time_with_condition(
assert 5e-9 >= abs(state_vector[0] - math.sin(time))
assert 5e-9 >= abs(state_vector[1] - math.cos(time))

def test_integrate_conditional_with_logger(
self,
initial_state: State,
state_logger: callable,
custom_condition: RealCondition,
capsys,
):
numerical_solver: NumericalSolver = NumericalSolver.conditional(
5.0,
1.0e-15,
1.0e-15,
state_logger,
)
end_time: float = initial_state.get_instant() + Duration.seconds(10.0)

numerical_solver.integrate_time(
initial_state, end_time, oscillator, custom_condition
)

captured = capsys.readouterr()

assert captured.out != ""

def test_default(self):
assert NumericalSolver.default() is not None

Expand All @@ -240,3 +285,11 @@ def test_default_conditional(self):
def test_undefined(self):
assert NumericalSolver.undefined() is not None
assert NumericalSolver.undefined().is_defined() is False

def test_conditional(self):
assert (
NumericalSolver.conditional(
initial_time_step, relative_tolerance, absolute_tolerance, state_logger
)
is not None
)
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,17 @@ class Propagator

/// @brief Calculate the state subject to an Event Condition, given initial state and maximum end time
/// @code
/// State state = propagator.calculateStateAt(aState, anInstant, anEventCondition);
/// NumericalSolver::ConditionSolution state = propagator.calculateStateAt(aState, anInstant,
/// anEventCondition);
/// @endcode
/// @param [in] aState An initial state
/// @param [in] anInstant An instant
/// @param [in] anEventCondition An event condition
/// @return State
/// @return NumericalSolver::ConditionSolution

State calculateStateAt(const State& aState, const Instant& anInstant, const EventCondition& anEventCondition) const;
NumericalSolver::ConditionSolution calculateStateAt(
const State& aState, const Instant& anInstant, const EventCondition& anEventCondition
) const;

/// @brief Calculate the states at an array of instants, given an initial state
/// @brief Can only be used with sorted instants array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,17 @@ class Segment
struct Solution
{
public:
Solution(const String& aName, const Array<Shared<Dynamics>>& aDynamicsArray, const Array<State>& aStates);
Solution(
const String& aName,
const Array<Shared<Dynamics>>& aDynamicsArray,
const Array<State>& aStates,
const bool& aConditionIsSatisfied
);

String name; /// Name of the segment.
Array<Shared<Dynamics>> dynamics; /// List of dynamics used.
Array<State> states; /// Array of states for the segment.
bool conditionIsSatisfied; /// True if the event condition is satisfied.
};

/// @brief Output stream operator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Sequence
struct Solution
{
Array<Segment::Solution> segmentSolutions;
bool sequenceIsComplete;

Array<State> getStates() const;
};
Expand Down
9 changes: 2 additions & 7 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ State Propagator::calculateStateAt(const State& aState, const Instant& anInstant
);
}

State Propagator::calculateStateAt(
NumericalSolver::ConditionSolution Propagator::calculateStateAt(
const State& aState, const Instant& anInstant, const EventCondition& anEventCondition
) const
{
Expand All @@ -182,12 +182,7 @@ State Propagator::calculateStateAt(
anEventCondition
);

if (!conditionSolution.conditionIsSatisfied)
{
throw ostk::core::error::RuntimeError("Condition not satisfied.");
}

return conditionSolution.state;
return conditionSolution;
}

Array<State> Propagator::calculateStatesAt(const State& aState, const Array<Instant>& anInstantArray) const
Expand Down
12 changes: 8 additions & 4 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Segment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,15 @@ using ostk::physics::time::Duration;
using ostk::astro::trajectory::Propagator;

Segment::Solution::Solution(
const String& aName, const Array<Shared<Dynamics>>& aDynamicsArray, const Array<State>& aStates
const String& aName,
const Array<Shared<Dynamics>>& aDynamicsArray,
const Array<State>& aStates,
const bool& aConditionIsSatisfied
)
: name(aName),
dynamics(aDynamicsArray),
states(aStates)
states(aStates),
conditionIsSatisfied(aConditionIsSatisfied)
{
}

Expand Down Expand Up @@ -108,14 +112,14 @@ Segment::Solution Segment::solve(const State& aState, const Duration& maximumPro

const Instant startInstant = aState.getInstant();

// TBI: Handle the case where the condition is not met
const State finalState =
const NumericalSolver::ConditionSolution conditionSolution =
propagator.calculateStateAt(aState, startInstant + maximumPropagationDuration, *eventCondition_);

return {
name_,
dynamics_,
propagator.accessNumericalSolver().accessObservedStates(),
conditionSolution.conditionIsSatisfied,
};
}

Expand Down
7 changes: 6 additions & 1 deletion src/OpenSpaceToolkit/Astrodynamics/Trajectory/Sequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,16 @@ Sequence::Solution Sequence::solve(const State& aState) const

segmentSolutions.add(segmentSolution);

if (!segmentSolution.conditionIsSatisfied)
{
return {segmentSolutions, false};
}

initialState = segmentSolution.states.accessLast();
}
}

return {segmentSolutions};
return {segmentSolutions, true};
}

void Sequence::print(std::ostream& anOutputStream, bool displayDecorator) const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,12 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Models_Propagator, Calcul

const Propagator propagator = {defaultRKD5_, defaultDynamics_};

const State endState = propagator.calculateStateAt(state, endInstant, condition);
const NumericalSolver::ConditionSolution conditionSolution =
propagator.calculateStateAt(state, endInstant, condition);

const State endState = conditionSolution.state;

EXPECT_TRUE(conditionSolution.conditionIsSatisfied);
EXPECT_TRUE(endState.accessInstant() < endInstant);
EXPECT_LT((endState.accessInstant() - condition.getInstant()).inSeconds(), 1e-7);

Expand All @@ -432,7 +436,7 @@ TEST_F(OpenSpaceToolkit_Astrodynamics_Trajectory_Orbit_Models_Propagator, Calcul
state.accessInstant() + Duration::Seconds(7000.0),
};

EXPECT_ANY_THROW(propagator.calculateStateAt(state, endInstant, failureCondition));
EXPECT_FALSE(propagator.calculateStateAt(state, endInstant, failureCondition).conditionIsSatisfied);
}
}

Expand Down

0 comments on commit 5ff075e

Please sign in to comment.