Skip to content

Commit

Permalink
feat: add custom logging option for conditional solver (#252)
Browse files Browse the repository at this point in the history
* feat: add custom logging option for conditional solver

* feat: fix python tests
  • Loading branch information
vishwa2710 authored Oct 17, 2023
1 parent c7eb9fd commit 8dc35fa
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_State_NumericalSolver(pyb
using ostk::astro::EventCondition;
using ostk::astro::trajectory::State;
using ostk::astro::trajectory::state::NumericalSolver;
using ostk::astro::RootSolver;

typedef std::function<MathNumericalSolver::StateVector(
const MathNumericalSolver::StateVector& x, MathNumericalSolver::StateVector& dxdt, const double t
Expand All @@ -40,12 +41,14 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_State_NumericalSolver(pyb
const NumericalSolver::StepperType&,
const Real&,
const Real&,
const Real&>(),
const Real&,
const RootSolver&>(),
arg("log_type"),
arg("stepper_type"),
arg("time_step"),
arg("relative_tolerance"),
arg("absolute_tolerance")
arg("absolute_tolerance"),
arg("root_solver") = RootSolver::Default()
)

.def(self == self)
Expand Down Expand Up @@ -135,7 +138,8 @@ inline void OpenSpaceToolkitAstrodynamicsPy_Trajectory_State_NumericalSolver(pyb

.def_static("default", &NumericalSolver::Default)
.def_static("undefined", &NumericalSolver::Undefined)
.def_static("default_conditional", &NumericalSolver::DefaultConditional)
.def_static("default_conditional", &NumericalSolver::DefaultConditional, arg("state_logger") = nullptr)
.def_static("conditional", &NumericalSolver::Conditional)

;
}
Expand Down
124 changes: 88 additions & 36 deletions bindings/python/test/trajectory/state/test_numerical_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,40 +67,60 @@ def custom_condition() -> RealCondition:


@pytest.fixture
def numerical_solver_default_inputs() -> (
tuple[NumericalSolver.LogType, NumericalSolver.StepperType, float, float, float]
):
log_type = NumericalSolver.LogType.NoLog
stepper_type = NumericalSolver.StepperType.RungeKuttaCashKarp54
initial_time_step = 5.0
relative_tolerance = 1.0e-15
absolute_tolerance = 1.0e-15

return (
log_type,
stepper_type,
initial_time_step,
relative_tolerance,
absolute_tolerance,
)
def log_type() -> NumericalSolver.LogType:
return NumericalSolver.LogType.NoLog


@pytest.fixture
def numerical_solver(numerical_solver_default_inputs) -> NumericalSolver:
return NumericalSolver(*numerical_solver_default_inputs)
def stepper_type() -> NumericalSolver.StepperType:
return NumericalSolver.StepperType.RungeKuttaCashKarp54


@pytest.fixture
def numerical_solver_conditional() -> NumericalSolver:
def initial_time_step() -> float:
return 5.0


@pytest.fixture
def relative_tolerance() -> float:
return 1.0e-15


@pytest.fixture
def absolute_tolerance() -> float:
return 1.0e-15


@pytest.fixture
def state_logger() -> callable:
def log_state(state: State) -> None:
print(state.get_coordinates())

return log_state


@pytest.fixture
def numerical_solver(
log_type: NumericalSolver.LogType,
stepper_type: NumericalSolver.StepperType,
initial_time_step: float,
relative_tolerance: float,
absolute_tolerance: float,
) -> NumericalSolver:
return NumericalSolver(
NumericalSolver.LogType.NoLog,
NumericalSolver.StepperType.RungeKuttaDopri5,
5.0,
1.0e-15,
1.0e-15,
log_type=log_type,
stepper_type=stepper_type,
time_step=initial_time_step,
relative_tolerance=relative_tolerance,
absolute_tolerance=absolute_tolerance,
)


@pytest.fixture
def numerical_solver_conditional() -> NumericalSolver:
return NumericalSolver.default_conditional()


class TestNumericalSolver:
def test_constructors(self, numerical_solver: NumericalSolver):
assert numerical_solver is not None
Expand All @@ -113,19 +133,13 @@ def test_comparators(self, numerical_solver: NumericalSolver):

def test_get_types(
self,
numerical_solver_default_inputs: tuple[
NumericalSolver.LogType, NumericalSolver.StepperType, float, float, float
],
log_type: NumericalSolver.LogType,
stepper_type: NumericalSolver.StepperType,
initial_time_step: float,
relative_tolerance: float,
absolute_tolerance: float,
numerical_solver: NumericalSolver,
):
(
log_type,
stepper_type,
initial_time_step,
relative_tolerance,
absolute_tolerance,
) = numerical_solver_default_inputs

assert numerical_solver.get_log_type() == log_type
assert numerical_solver.get_stepper_type() == stepper_type
assert numerical_solver.get_time_step() == initial_time_step
Expand Down Expand Up @@ -231,12 +245,50 @@ def test_integrate_time_with_condition(
assert 5e-9 >= abs(state_vector[0] - math.sin(time))
assert 5e-9 >= abs(state_vector[1] - math.cos(time))

def test_integrate_conditional_with_logger(
self,
initial_state: State,
state_logger: callable,
custom_condition: RealCondition,
capsys,
):
numerical_solver: NumericalSolver = NumericalSolver.conditional(
5.0,
1.0e-15,
1.0e-15,
state_logger,
)
end_time: float = initial_state.get_instant() + Duration.seconds(10.0)

numerical_solver.integrate_time(
initial_state, end_time, oscillator, custom_condition
)

captured = capsys.readouterr()

assert captured.out != ""

def test_default(self):
assert NumericalSolver.default() is not None

def test_default_conditional(self):
def test_default_conditional(self, state_logger):
assert NumericalSolver.default_conditional() is not None
assert NumericalSolver.default_conditional(state_logger) is not None

def test_undefined(self):
assert NumericalSolver.undefined() is not None
assert NumericalSolver.undefined().is_defined() is False

def test_conditional(
self,
initial_time_step: float,
relative_tolerance: float,
absolute_tolerance: float,
state_logger,
):
assert (
NumericalSolver.conditional(
initial_time_step, relative_tolerance, absolute_tolerance, state_logger
)
is not None
)
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ using MathNumericalSolver = ostk::math::solvers::NumericalSolver;
class NumericalSolver : public MathNumericalSolver
{
public:
/// @brief Structure to hold the condition solution.

struct ConditionSolution
{
State state; ///< Final state after integration.
bool conditionIsSatisfied; ///< Whether the condition is met.
Size iterationCount; ///< Number of iterations performed.
bool rootSolverHasConverged; ///< Whether the root solver has converged.
};

/// @brief Constructor
///
/// @code
Expand All @@ -63,16 +73,6 @@ class NumericalSolver : public MathNumericalSolver
const RootSolver& aRootSolver = RootSolver::Default()
);

/// @brief Structure to hold the condition solution.

struct ConditionSolution
{
State state; ///< Final state after integration.
bool conditionIsSatisfied; ///< Whether the condition is met.
Size iterationCount; ///< Number of iterations performed.
bool rootSolverHasConverged; ///< Whether the root solver has converged.
};

/// @brief Access observed states
///
/// @code
Expand Down Expand Up @@ -155,9 +155,26 @@ class NumericalSolver : public MathNumericalSolver

/// @brief Default conditional
///
/// @return A default conditional numerical solver
/// @param [in] stateLogger A function that takes a `State` object and logs. Defaults to `nullptr`.
/// @return A default conditional numerical solver.

static NumericalSolver DefaultConditional(const std::function<void(const State&)>& stateLogger = nullptr);

static NumericalSolver DefaultConditional();
/// @brief Create a conditional numerical solver.
///
/// @param [in] aTimeStep The initial time step to use.
/// @param [in] aRelativeTolerance The relative tolerance to use.
/// @param [in] anAbsoluteTolerance The absolute tolerance to use.
/// @param [in] stateLogger A function that takes a `State` object and logs.
///
/// @return A conditional numerical solver.

static NumericalSolver Conditional(
const Real& aTimeStep,
const Real& aRelativeTolerance,
const Real& anAbsoluteTolerance,
const std::function<void(const State&)>& stateLogger
);

/// Delete undesired methods from parent

Expand Down Expand Up @@ -194,6 +211,37 @@ class NumericalSolver : public MathNumericalSolver
private:
RootSolver rootSolver_;
Array<State> observedStates_;
std::function<void(const State&)> stateLogger_;

/// @brief Constructor
///
/// @code
/// NumericalSolver numericalSolver = { aLogType, aStepperType, aTimeStep,
/// aRelativeTolerance, anAbsoluteTolerance };
/// @endcode
///
/// @param [in] aLogType An enum indicating the amount of verbosity wanted to be logged during
/// numerical integration
/// @param [in] aStepperType An enum indicating the type of numerical stepper used to perform
/// integration
/// @param [in] aTimeStep A number indicating the initial guess time step the numerical solver will
/// take
/// @param [in] aRelativeTolerance A number indicating the relative integration tolerance
/// @param [in] anAbsoluteTolerance A number indicating the absolute integration tolerance
/// @param [in] aRootSolver A root solver to be used to solve the event condition
/// @param [in] stateLogger A function that takes a `State` object and logs

NumericalSolver(
const NumericalSolver::LogType& aLogType,
const NumericalSolver::StepperType& aStepperType,
const Real& aTimeStep,
const Real& aRelativeTolerance,
const Real& anAbsoluteTolerance,
const RootSolver& aRootSolver,
const std::function<void(const State&)>& stateLogger
);

void observeState(const State& aState);
};

} // namespace state
Expand Down
Loading

0 comments on commit 8dc35fa

Please sign in to comment.