diff --git a/bindings/python/test/trajectory/test_state.py b/bindings/python/test/trajectory/test_state.py index 0582a8cc6..3a7c427cb 100644 --- a/bindings/python/test/trajectory/test_state.py +++ b/bindings/python/test/trajectory/test_state.py @@ -407,6 +407,82 @@ def test_from_dict_with_mass(self): assert isinstance(state, State) assert state.get_size() == 7 + @pytest.mark.parametrize( + ("data", "expected_length", "expected_frame"), + [ + ( + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "r_GCRF_x": 1.0, + "r_GCRF_y": 2.0, + "r_GCRF_z": 3.0, + "v_GCRF_x": 4.0, + "v_GCRF_y": 5.0, + "v_GCRF_z": 6.0, + }, + 6, + Frame.GCRF(), + ), + ( + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "r_ITRF_x": 1.0, + "r_ITRF_y": 2.0, + "r_ITRF_z": 3.0, + "v_ITRF_x": 4.0, + "v_ITRF_y": 5.0, + "v_ITRF_z": 6.0, + }, + 6, + Frame.ITRF(), + ), + ], + ) + def test_from_dict_cannonical_success( + self, data: dict, expected_length: int, expected_frame: Frame + ): + state: State = State.from_dict(data) + + assert state is not None + assert isinstance(state, State) + + assert state.get_size() == expected_length + assert state.get_frame() == expected_frame + + @pytest.mark.parametrize( + ("data", "expected_failure_message"), + [ + ( + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "r_GCRF_x": 1.0, + "r_GCRF_y": 2.0, + "r_GCRF_z": 3.0, + "v_GCRF_x": 4.0, + "v_GCRF_y": 5.0, + }, + "Invalid state data.", + ), + ( + { + "timestamp": datetime.now(timezone.utc).isoformat(), + "r_TEST_x": 1.0, + "r_TEST_y": 2.0, + "r_TEST_z": 3.0, + "v_TEST_x": 4.0, + "v_TEST_y": 5.0, + "v_TEST_z": 6.0, + }, + "No frame exists with name \\[TEST\\].", + ), + ], + ) + def test_from_dict_cannonical_failure( + self, data: dict, expected_failure_message: str + ): + with pytest.raises(ValueError, match=expected_failure_message): + State.from_dict(data) + def test_comparators(self, state: State): assert (state == state) is True assert (state != state) is False diff --git a/bindings/python/tools/python/ostk/astrodynamics/pytrajectory/pystate.py b/bindings/python/tools/python/ostk/astrodynamics/pytrajectory/pystate.py index 490f4bea8..3d7b24be9 100644 --- a/bindings/python/tools/python/ostk/astrodynamics/pytrajectory/pystate.py +++ b/bindings/python/tools/python/ostk/astrodynamics/pytrajectory/pystate.py @@ -1,7 +1,8 @@ # Apache License 2.0 - # Python-only State functionality +import re + import numpy as np from ostk.physics.coordinate import Frame @@ -17,6 +18,8 @@ AngularVelocity, ) +CANONICAL_FORMAT: str = r"(r|v)_(.*?)_(x|y|z)" + @staticmethod def custom_class_generator(frame: Frame, coordinate_subsets: list) -> type: @@ -45,14 +48,17 @@ def __init__(self, instant: Instant, coordinates: np.ndarray): def from_dict(data: dict) -> State: """ Create a State from a dictionary. + + Note: Implicit assumption that ECEF = ITRF, and ECI = GCRF. + The dictionary must contain the following: - 'timestamp': The timestamp of the state. - - 'rx'/'x_eci'/'x_ecef': The x-coordinate of the position. - - 'ry'/'y_eci'/'y_ecef': The y-coordinate of the position. - - 'rz'/'z_eci'/'z_ecef': The z-coordinate of the position. - - 'vx'/'vx_eci'/'vx_ecef': The x-coordinate of the velocity. - - 'vy'/'vy_eci'/'vy_ecef': The y-coordinate of the velocity. - - 'vz'/'vz_eci'/'vz_ecef': The z-coordinate of the velocity. + - 'r_ITRF_x'/'rx'/'rx_eci'/'rx_ecef': The x-coordinate of the position. + - 'r_ITRF_y'/'ry'/'ry_eci'/'ry_ecef': The y-coordinate of the position. + - 'r_ITRF_z'/'rz'/'rz_eci'/'rz_ecef': The z-coordinate of the position. + - 'v_ITRF_x'/'vx'/'vx_eci'/'vx_ecef': The x-coordinate of the velocity. + - 'v_ITRF_y'/'vy'/'vy_eci'/'vy_ecef': The y-coordinate of the velocity. + - 'v_ITRF_z'/'vz'/'vz_eci'/'vz_ecef': The z-coordinate of the velocity. - 'frame': The frame of the state. Required if 'rx', 'ry', 'rz', 'vx', 'vy', 'vz' are provided. - 'q_B_ECI_x': The x-coordinate of the quaternion. Optional. - 'q_B_ECI_y': The y-coordinate of the quaternion. Optional. @@ -65,7 +71,6 @@ def from_dict(data: dict) -> State: - 'cross_sectional_area'/'surface_area': The cross-sectional area. Optional. - 'mass': The mass. Optional. - Args: data (dict): The dictionary. @@ -100,6 +105,17 @@ def from_dict(data: dict) -> State: "vz", ] + # Replace non-standard position keys with canonical representation + if all(key in data.keys() for key in ("x_eci", "y_eci", "z_eci")): + data["rx_eci"] = data["x_eci"] + data["ry_eci"] = data["y_eci"] + data["rz_eci"] = data["z_eci"] + + if all(key in data.keys() for key in ("x_ecef", "y_ecef", "z_ecef")): + data["rx_ecef"] = data["x_ecef"] + data["ry_ecef"] = data["y_ecef"] + data["rz_ecef"] = data["z_ecef"] + frame: Frame coordinates: np.ndarray @@ -108,8 +124,21 @@ def from_dict(data: dict) -> State: CartesianVelocity.default(), ] - if all(column in data for column in eci_columns): - frame = Frame.GCRF() + match_groups: list[re.Match] = [ + re.match(CANONICAL_FORMAT, column) for column in data.keys() + ] + + if len(matches := [match for match in match_groups if match is not None]) == 6: + frame_name: str = matches[0].group(2) + try: + frame: Frame = Frame.with_name(frame_name) or getattr(Frame, frame_name)() + except Exception: + raise ValueError(f"No frame exists with name [{frame_name}].") + + coordinates = np.array([data[match.group(0)] for match in matches]) + + elif all(column in data for column in eci_columns): + frame: Frame = Frame.GCRF() coordinates = np.array( [ data["rx_eci"], @@ -120,6 +149,7 @@ def from_dict(data: dict) -> State: data["vz_eci"], ] ) + elif all(column in data for column in ecef_columns): frame = Frame.ITRF() coordinates = np.array( @@ -132,6 +162,7 @@ def from_dict(data: dict) -> State: data["vz_ecef"], ] ) + elif all(column in data for column in generic_columns): if "frame" not in data: raise ValueError("Frame must be provided for generic columns.")