Skip to content

Commit

Permalink
feat: add more fields for pystate.from_dict (#453)
Browse files Browse the repository at this point in the history
  • Loading branch information
vishwa2710 authored Oct 31, 2024
1 parent 88e4315 commit 75f2689
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 10 deletions.
76 changes: 76 additions & 0 deletions bindings/python/test/trajectory/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,82 @@ def test_from_dict_with_mass(self):
assert isinstance(state, State)
assert state.get_size() == 7

@pytest.mark.parametrize(
("data", "expected_length", "expected_frame"),
[
(
{
"timestamp": datetime.now(timezone.utc).isoformat(),
"r_GCRF_x": 1.0,
"r_GCRF_y": 2.0,
"r_GCRF_z": 3.0,
"v_GCRF_x": 4.0,
"v_GCRF_y": 5.0,
"v_GCRF_z": 6.0,
},
6,
Frame.GCRF(),
),
(
{
"timestamp": datetime.now(timezone.utc).isoformat(),
"r_ITRF_x": 1.0,
"r_ITRF_y": 2.0,
"r_ITRF_z": 3.0,
"v_ITRF_x": 4.0,
"v_ITRF_y": 5.0,
"v_ITRF_z": 6.0,
},
6,
Frame.ITRF(),
),
],
)
def test_from_dict_cannonical_success(
self, data: dict, expected_length: int, expected_frame: Frame
):
state: State = State.from_dict(data)

assert state is not None
assert isinstance(state, State)

assert state.get_size() == expected_length
assert state.get_frame() == expected_frame

@pytest.mark.parametrize(
("data", "expected_failure_message"),
[
(
{
"timestamp": datetime.now(timezone.utc).isoformat(),
"r_GCRF_x": 1.0,
"r_GCRF_y": 2.0,
"r_GCRF_z": 3.0,
"v_GCRF_x": 4.0,
"v_GCRF_y": 5.0,
},
"Invalid state data.",
),
(
{
"timestamp": datetime.now(timezone.utc).isoformat(),
"r_TEST_x": 1.0,
"r_TEST_y": 2.0,
"r_TEST_z": 3.0,
"v_TEST_x": 4.0,
"v_TEST_y": 5.0,
"v_TEST_z": 6.0,
},
"No frame exists with name \\[TEST\\].",
),
],
)
def test_from_dict_cannonical_failure(
self, data: dict, expected_failure_message: str
):
with pytest.raises(ValueError, match=expected_failure_message):
State.from_dict(data)

def test_comparators(self, state: State):
assert (state == state) is True
assert (state != state) is False
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Apache License 2.0


# Python-only State functionality
import re

import numpy as np

from ostk.physics.coordinate import Frame
Expand All @@ -17,6 +18,8 @@
AngularVelocity,
)

CANONICAL_FORMAT: str = r"(r|v)_(.*?)_(x|y|z)"


@staticmethod
def custom_class_generator(frame: Frame, coordinate_subsets: list) -> type:
Expand Down Expand Up @@ -45,14 +48,17 @@ def __init__(self, instant: Instant, coordinates: np.ndarray):
def from_dict(data: dict) -> State:
"""
Create a State from a dictionary.
Note: Implicit assumption that ECEF = ITRF, and ECI = GCRF.
The dictionary must contain the following:
- 'timestamp': The timestamp of the state.
- 'rx'/'x_eci'/'x_ecef': The x-coordinate of the position.
- 'ry'/'y_eci'/'y_ecef': The y-coordinate of the position.
- 'rz'/'z_eci'/'z_ecef': The z-coordinate of the position.
- 'vx'/'vx_eci'/'vx_ecef': The x-coordinate of the velocity.
- 'vy'/'vy_eci'/'vy_ecef': The y-coordinate of the velocity.
- 'vz'/'vz_eci'/'vz_ecef': The z-coordinate of the velocity.
- 'r_ITRF_x'/'rx'/'rx_eci'/'rx_ecef': The x-coordinate of the position.
- 'r_ITRF_y'/'ry'/'ry_eci'/'ry_ecef': The y-coordinate of the position.
- 'r_ITRF_z'/'rz'/'rz_eci'/'rz_ecef': The z-coordinate of the position.
- 'v_ITRF_x'/'vx'/'vx_eci'/'vx_ecef': The x-coordinate of the velocity.
- 'v_ITRF_y'/'vy'/'vy_eci'/'vy_ecef': The y-coordinate of the velocity.
- 'v_ITRF_z'/'vz'/'vz_eci'/'vz_ecef': The z-coordinate of the velocity.
- 'frame': The frame of the state. Required if 'rx', 'ry', 'rz', 'vx', 'vy', 'vz' are provided.
- 'q_B_ECI_x': The x-coordinate of the quaternion. Optional.
- 'q_B_ECI_y': The y-coordinate of the quaternion. Optional.
Expand All @@ -65,7 +71,6 @@ def from_dict(data: dict) -> State:
- 'cross_sectional_area'/'surface_area': The cross-sectional area. Optional.
- 'mass': The mass. Optional.
Args:
data (dict): The dictionary.
Expand Down Expand Up @@ -100,6 +105,17 @@ def from_dict(data: dict) -> State:
"vz",
]

# Replace non-standard position keys with canonical representation
if all(key in data.keys() for key in ("x_eci", "y_eci", "z_eci")):
data["rx_eci"] = data["x_eci"]
data["ry_eci"] = data["y_eci"]
data["rz_eci"] = data["z_eci"]

if all(key in data.keys() for key in ("x_ecef", "y_ecef", "z_ecef")):
data["rx_ecef"] = data["x_ecef"]
data["ry_ecef"] = data["y_ecef"]
data["rz_ecef"] = data["z_ecef"]

frame: Frame
coordinates: np.ndarray

Expand All @@ -108,8 +124,21 @@ def from_dict(data: dict) -> State:
CartesianVelocity.default(),
]

if all(column in data for column in eci_columns):
frame = Frame.GCRF()
match_groups: list[re.Match] = [
re.match(CANONICAL_FORMAT, column) for column in data.keys()
]

if len(matches := [match for match in match_groups if match is not None]) == 6:
frame_name: str = matches[0].group(2)
try:
frame: Frame = Frame.with_name(frame_name) or getattr(Frame, frame_name)()
except Exception:
raise ValueError(f"No frame exists with name [{frame_name}].")

coordinates = np.array([data[match.group(0)] for match in matches])

elif all(column in data for column in eci_columns):
frame: Frame = Frame.GCRF()
coordinates = np.array(
[
data["rx_eci"],
Expand All @@ -120,6 +149,7 @@ def from_dict(data: dict) -> State:
data["vz_eci"],
]
)

elif all(column in data for column in ecef_columns):
frame = Frame.ITRF()
coordinates = np.array(
Expand All @@ -132,6 +162,7 @@ def from_dict(data: dict) -> State:
data["vz_ecef"],
]
)

elif all(column in data for column in generic_columns):
if "frame" not in data:
raise ValueError("Frame must be provided for generic columns.")
Expand Down

0 comments on commit 75f2689

Please sign in to comment.