Skip to content

Commit

Permalink
feat: clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vishwa2710 committed Jun 18, 2024
1 parent 5f72941 commit bc0e81c
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 84 deletions.
1 change: 1 addition & 0 deletions bindings/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ open-space-toolkit-core~=3.0
open-space-toolkit-io~=3.0
open-space-toolkit-mathematics~=3.0
open-space-toolkit-physics~=7.0
numpy~=1.26
29 changes: 26 additions & 3 deletions bindings/python/test/trajectory/test_orbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,19 @@ def test_get_pass_with_revolution_number(self, orbit: Orbit):
assert isinstance(pass_, Pass)
assert pass_.is_defined()

assert orbit.get_pass_with_revolution_number(2, Duration.minutes(10.0)) is not None
assert (
orbit.get_pass_with_revolution_number(2, Duration.minutes(10.0)) is not None
)

def test_get_passes_within_interval(self, orbit: Orbit):
passes: list[Pass] = orbit.get_passes_within_interval(
Interval.closed(
Instant.date_time(DateTime(2018, 1, 1, 0, 0, 0), Scale.UTC),
Instant.date_time(DateTime(2018, 1, 1, 0, 10, 0), Scale.UTC),
)
)

assert len(passes) > 0

def test_undefined(self):
assert Orbit.undefined().is_defined() is False
Expand Down Expand Up @@ -169,6 +181,17 @@ def test_sun_synchronous(self, earth):
argument_of_latitude=Angle.degrees(50.0),
).is_defined()

def test_compute_passes(self, orbit: Orbit, states: list[State]):
passes: list[tuple[int, Pass]] = orbit.compute_passes(states, 1)
def test_compute_passes(self, states: list[State]):
passes: list[tuple[int, Pass]] = Orbit.compute_passes(states, 1)
assert passes is not None

def test_compute_passes_with_model(self, orbit: Orbit):
passes: list[tuple[int, Pass]] = Orbit.compute_passes_with_model(
model=orbit.access_kepler_model(),
interval=Interval.closed(
Instant.date_time(DateTime(2018, 1, 1, 0, 0, 0), Scale.UTC),
Instant.date_time(DateTime(2018, 1, 1, 0, 10, 0), Scale.UTC),
),
)

assert len(passes) > 0
163 changes: 83 additions & 80 deletions src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ Pass Orbit::getPassWithRevolutionNumber(const Integer& aRevolutionNumber, const

const std::lock_guard<std::mutex> lock {this->mutex_};

const auto getClosestPass = [this](const Integer& aRevolutionNumber) -> Pass
const auto getClosestPass = [this](const Integer& revolutionNumber) -> Pass
{
if (this->passMap_.empty())
{
Expand All @@ -190,23 +190,23 @@ Pass Orbit::getPassWithRevolutionNumber(const Integer& aRevolutionNumber, const

// exact revolution number exists

if (this->passMap_.count(aRevolutionNumber))
if (this->passMap_.count(revolutionNumber))
{
return this->passMap_.at(aRevolutionNumber);
return this->passMap_.at(revolutionNumber);
}

const auto lowerBoundMapIt = this->passMap_.lower_bound(aRevolutionNumber);
const auto lowerBoundMapIt = this->passMap_.lower_bound(revolutionNumber);

// Revolution number is greater than any existing revolution number in map
// {5, 6, 9, 10} -> aRevolutionNumber=12 -> return 10
// {5, 6, 9, 10} -> revolutionNumber=12 -> return 10

if (lowerBoundMapIt == this->passMap_.end())
{
return this->passMap_.rbegin()->second;
}

// Revolution number is lesser than any existing revolution number in map
// {5, 6, 9, 10} -> aRevolutionNumber=4 -> return 5
// {5, 6, 9, 10} -> revolutionNumber=4 -> return 5

if (lowerBoundMapIt == this->passMap_.begin())
{
Expand All @@ -217,15 +217,15 @@ Pass Orbit::getPassWithRevolutionNumber(const Integer& aRevolutionNumber, const

const auto closestPassMapIt = std::prev(lowerBoundMapIt);

// {5, 6, 9, 10} -> aRevolutionNumber=7 -> return 6
// {5, 6, 9, 10} -> revolutionNumber=7 -> return 6
// lowerBoundMapIt = 9, closestPassMapIt = 6

if ((aRevolutionNumber - closestPassMapIt->first) < (lowerBoundMapIt->first - aRevolutionNumber))
if ((revolutionNumber - closestPassMapIt->first) < (lowerBoundMapIt->first - revolutionNumber))

Check warning on line 223 in src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp

View check run for this annotation

Codecov / codecov/patch

src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit.cpp#L223

Added line #L223 was not covered by tests
{
return closestPassMapIt->second;
}

// {5, 6, 9, 10} -> aRevolutionNumber=8 -> return 9
// {5, 6, 9, 10} -> revolutionNumber=8 -> return 9
// lowerBoundMapIt = 9, closestPassMapIt = 6

return lowerBoundMapIt->second;
Expand Down Expand Up @@ -438,7 +438,7 @@ Array<Pass> Orbit::getPassesWithinInterval(const Interval& anInterval) const

if (pass.getInterval().contains(currentInstant))
{
revolutionNumber = passIt.first;
revolutionNumber = pass.getRevolutionNumber();
break;
}
}
Expand All @@ -451,11 +451,11 @@ Array<Pass> Orbit::getPassesWithinInterval(const Interval& anInterval) const

Array<Pass> passes = {};

while (currentInstant.isDefined() && (currentInstant <= anInterval.accessEnd()))
while (currentInstant <= anInterval.accessEnd())
{
const Pass pass = this->getPassWithRevolutionNumber(revolutionNumber);

passes.add(this->getPassWithRevolutionNumber(revolutionNumber));
passes.add(pass);

currentInstant = pass.accessInstantAtPassBreak();
revolutionNumber++;
Expand Down Expand Up @@ -1311,17 +1311,7 @@ Array<Pass> Orbit::ComputePassesWithModel(const orbit::Model& aModel, const Inte

Array<Pass> passes;

// [TBI] Make a parameter?
const Array<Instant> instants = anInterval.generateGrid(Duration::Minutes(5.0));

const Array<State> states = instants.map<State>(
[&aModel](const Instant& anInstant) -> State
{
return aModel.calculateStateAt(anInstant);
}
);

const Instant& epoch = aModel.getEpoch();
const Instant epoch = aModel.getEpoch();

const auto getZ = [&aModel, &epoch](const double& aDurationInSeconds) -> double
{
Expand All @@ -1339,84 +1329,97 @@ Array<Pass> Orbit::ComputePassesWithModel(const orbit::Model& aModel, const Inte
.z();
};

State const* previousStatePtr = nullptr;

Integer revolutionNumber = aModel.getRevolutionNumberAtEpoch();
State previousState = aModel.calculateStateAt(anInterval.getStart());
Duration stepDuration = Duration::Minutes(5.0);

Instant previousPassEndInstant =
(Real(states.accessFirst().getPosition().accessCoordinates().z()).isNear(0.0, epsilon))
? states.accessFirst().accessInstant()
: Instant::Undefined();
Instant previousPassEndInstant = (Real(previousState.getPosition().accessCoordinates().z()).isNear(0.0, epsilon))
? previousState.accessInstant()
: Instant::Undefined();
Instant northPointCrossing = Instant::Undefined();
Instant descendingNodeCrossing = Instant::Undefined();
Instant southPointCrossing = Instant::Undefined();
Instant passBreak = Instant::Undefined();

for (const auto& state : states)
while (true)
{
if (previousStatePtr != nullptr)
const Instant currentInstant = previousState.accessInstant() + stepDuration;

if (currentInstant >= anInterval.getEnd())
{
const Vector3d previousPositionCoordinates_ECI = previousStatePtr->getPosition().accessCoordinates();
const Vector3d previousVelocityCoordinates_ECI = previousStatePtr->getVelocity().accessCoordinates();
const Vector3d currentPositionCoordinates_ECI = state.getPosition().accessCoordinates();
const Vector3d currentVelocityCoordinates_ECI = state.getVelocity().accessCoordinates();

// North & South point crossings
if (((previousVelocityCoordinates_ECI.z() > 0.0) && (currentVelocityCoordinates_ECI.z() <= 0.0)) ||
((previousVelocityCoordinates_ECI.z() < 0.0) && (currentVelocityCoordinates_ECI.z() >= 0.0)))
{
if (currentPositionCoordinates_ECI.z() > 0.0)
{
northPointCrossing = Orbit::GetCrossingInstant(
epoch, previousStatePtr->accessInstant(), state.accessInstant(), getZDot
);
}
else
{
southPointCrossing = Orbit::GetCrossingInstant(
epoch, previousStatePtr->accessInstant(), state.accessInstant(), getZDot
);
}
}
break;
}

const State currentState = aModel.calculateStateAt(currentInstant);

// Descending node
if ((previousPositionCoordinates_ECI.z() > 0.0) && (currentPositionCoordinates_ECI.z() <= 0.0))
const Vector3d previousPositionCoordinates_ECI = previousState.getPosition().accessCoordinates();
const Vector3d previousVelocityCoordinates_ECI = previousState.getVelocity().accessCoordinates();
const Vector3d currentPositionCoordinates_ECI = currentState.getPosition().accessCoordinates();
const Vector3d currentVelocityCoordinates_ECI = currentState.getVelocity().accessCoordinates();

if (((previousVelocityCoordinates_ECI.z() > 0.0) && (currentVelocityCoordinates_ECI.z() <= 0.0)) ||
((previousVelocityCoordinates_ECI.z() < 0.0) && (currentVelocityCoordinates_ECI.z() >= 0.0)))
{
if (currentPositionCoordinates_ECI.z() > 0.0)
{
descendingNodeCrossing =
Orbit::GetCrossingInstant(epoch, previousStatePtr->accessInstant(), state.accessInstant(), getZ);
northPointCrossing = Orbit::GetCrossingInstant(
epoch, previousState.accessInstant(), currentState.accessInstant(), getZDot
);
}

// Pass break
if ((previousPositionCoordinates_ECI.z() < 0.0) && (currentPositionCoordinates_ECI.z() >= 0.0))
else
{
passBreak =
Orbit::GetCrossingInstant(epoch, previousStatePtr->accessInstant(), state.accessInstant(), getZ);
southPointCrossing = Orbit::GetCrossingInstant(
epoch, previousState.accessInstant(), currentState.accessInstant(), getZDot
);
}
}

if (passBreak.isDefined())
{
const Pass pass = {
revolutionNumber,
previousPassEndInstant,
northPointCrossing,
descendingNodeCrossing,
southPointCrossing,
passBreak,
};
if ((previousPositionCoordinates_ECI.z() > 0.0) && (currentPositionCoordinates_ECI.z() <= 0.0))
{
descendingNodeCrossing =
Orbit::GetCrossingInstant(epoch, previousState.accessInstant(), currentState.accessInstant(), getZ);
}

passes.add(pass);
if ((previousPositionCoordinates_ECI.z() < 0.0) && (currentPositionCoordinates_ECI.z() >= 0.0))
{
passBreak =
Orbit::GetCrossingInstant(epoch, previousState.accessInstant(), currentState.accessInstant(), getZ);
}

if (passBreak.isDefined())
{
const Pass pass = {
revolutionNumber,
previousPassEndInstant,
northPointCrossing,
descendingNodeCrossing,
southPointCrossing,
passBreak,
};

revolutionNumber++;
previousPassEndInstant = passBreak;
passes.add(pass);

southPointCrossing = Instant::Undefined();
northPointCrossing = Instant::Undefined();
descendingNodeCrossing = Instant::Undefined();
passBreak = Instant::Undefined();
revolutionNumber++;
previousPassEndInstant = passBreak;

southPointCrossing = Instant::Undefined();
northPointCrossing = Instant::Undefined();
descendingNodeCrossing = Instant::Undefined();
passBreak = Instant::Undefined();

if (pass.isComplete())
{
Array<Duration> durations = {
(pass.accessInstantAtNorthPoint() - pass.accessInstantAtAscendingNode()),
(pass.accessInstantAtDescendingNode() - pass.accessInstantAtNorthPoint()),
(pass.accessInstantAtSouthPoint() - pass.accessInstantAtDescendingNode()),
(pass.accessInstantAtPassBreak() - pass.accessInstantAtSouthPoint()),
};
stepDuration = *std::min_element(durations.begin(), durations.end()) / 2.0;
}
}

previousStatePtr = &state;
previousState = currentState;
}

// Add last partial pass
Expand Down
7 changes: 6 additions & 1 deletion src/OpenSpaceToolkit/Astrodynamics/Trajectory/Orbit/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,13 @@ Instant Pass::getEndInstant() const
return instantAtAscendingNode_;
}

const Interval Pass::getInterval() const
Interval Pass::getInterval() const
{
if (!this->isDefined())
{
throw ostk::core::error::runtime::Undefined("Pass");
}

if (type_ == Pass::Type::Complete)
{
return Interval::Closed(instantAtAscendingNode_, instantAtPassBreak_);
Expand Down

0 comments on commit bc0e81c

Please sign in to comment.