diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index baabe69..ed0ffe7 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -35,6 +35,7 @@ In the previous section, we only changed the length of that the uniform lengthen ```python def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int, randomize: bool = True) -> (Tensor, Tensor): + x = remove_infinite_flux_data_points_from_light_curve_observation(x) x = remove_nan_flux_data_points_from_light_curve_observation(x) if randomize: x = randomly_roll_light_curve_observation(x) diff --git a/src/qusi/internal/light_curve.py b/src/qusi/internal/light_curve.py index d866dda..5a9bf7b 100644 --- a/src/qusi/internal/light_curve.py +++ b/src/qusi/internal/light_curve.py @@ -48,6 +48,21 @@ def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> Lig return light_curve +def remove_infinite_flux_data_points_from_light_curve(light_curve: LightCurve) -> LightCurve: + """ + Removes infinite values from a light curve. If there is an infinite value in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve: The light curve. + :return: The light curve with infinite values removed. + """ + light_curve = deepcopy(light_curve) + infinite_flux_indexes = np.isinf(light_curve.fluxes) + light_curve.fluxes = light_curve.fluxes[~infinite_flux_indexes] + light_curve.times = light_curve.times[~infinite_flux_indexes] + return light_curve + + def randomly_roll_light_curve(light_curve: LightCurve) -> LightCurve: """ Randomly rolls a light curve. That is, a random position in the light curve is chosen, the light curve diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index 3ce563a..dafc74b 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -25,12 +25,13 @@ from qusi.internal.light_curve import ( LightCurve, randomly_roll_light_curve, - remove_nan_flux_data_points_from_light_curve, + remove_nan_flux_data_points_from_light_curve, remove_infinite_flux_data_points_from_light_curve, ) from qusi.internal.light_curve_observation import ( LightCurveObservation, randomly_roll_light_curve_observation, remove_nan_flux_data_points_from_light_curve_observation, + remove_infinite_flux_data_points_from_light_curve_observation, ) from qusi.internal.light_curve_transforms import ( from_light_curve_observation_to_fluxes_array_and_label_array, @@ -357,6 +358,7 @@ def default_light_curve_observation_post_injection_transform( :param randomize: Whether to have randomization in the transforms. :return: The transformed light curve observation. """ + x = remove_infinite_flux_data_points_from_light_curve_observation(x) x = remove_nan_flux_data_points_from_light_curve_observation(x) if randomize: x = randomly_roll_light_curve_observation(x) @@ -382,6 +384,7 @@ def default_light_curve_post_injection_transform( :param randomize: Whether to have randomization in the transforms. :return: The transformed light curve. """ + x = remove_infinite_flux_data_points_from_light_curve(x) x = remove_nan_flux_data_points_from_light_curve(x) if randomize: x = randomly_roll_light_curve(x) diff --git a/src/qusi/internal/light_curve_observation.py b/src/qusi/internal/light_curve_observation.py index aa1cc86..53f12ee 100644 --- a/src/qusi/internal/light_curve_observation.py +++ b/src/qusi/internal/light_curve_observation.py @@ -3,7 +3,9 @@ from typing_extensions import Self -from qusi.internal.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve import (LightCurve, randomly_roll_light_curve, + remove_nan_flux_data_points_from_light_curve, + remove_infinite_flux_data_points_from_light_curve) @dataclass @@ -48,6 +50,23 @@ def remove_nan_flux_data_points_from_light_curve_observation( return light_curve_observation +def remove_infinite_flux_data_points_from_light_curve_observation( + light_curve_observation: LightCurveObservation, +) -> LightCurveObservation: + """ + Removes the inf values from a light curve in a light curve observation. If there is an inf in either the times or the + fluxes, both corresponding values are removed. + + :param light_curve_observation: The light curve observation. + :return: The light curve observation with inf values removed. + """ + light_curve_observation = deepcopy(light_curve_observation) + light_curve_observation.light_curve = remove_infinite_flux_data_points_from_light_curve( + light_curve_observation.light_curve + ) + return light_curve_observation + + def randomly_roll_light_curve_observation(light_curve_observation: LightCurveObservation) -> LightCurveObservation: """ Randomly rolls a light curve observation. That is, a random position in the light curve is chosen, the light curve diff --git a/src/qusi/transform.py b/src/qusi/transform.py index ba3b863..dc5a3da 100644 --- a/src/qusi/transform.py +++ b/src/qusi/transform.py @@ -1,11 +1,12 @@ """ Data transform related public interface. """ -from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve +from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve, \ + remove_infinite_flux_data_points_from_light_curve from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \ default_light_curve_observation_post_injection_transform from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \ - randomly_roll_light_curve_observation + randomly_roll_light_curve_observation, remove_infinite_flux_data_points_from_light_curve_observation from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \ pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score @@ -20,4 +21,6 @@ 'randomly_roll_light_curve_observation', 'remove_nan_flux_data_points_from_light_curve', 'remove_nan_flux_data_points_from_light_curve_observation', + 'remove_infinite_flux_data_points_from_light_curve', + 'remove_infinite_flux_data_points_from_light_curve_observation', ] diff --git a/tests/unit_tests/test_transform.py b/tests/unit_tests/test_transform.py new file mode 100644 index 0000000..aff7130 --- /dev/null +++ b/tests/unit_tests/test_transform.py @@ -0,0 +1,17 @@ +import numpy as np + +from qusi.internal.light_curve import LightCurve, remove_infinite_flux_data_points_from_light_curve + + +def test_remove_infinite_flux_data_points_from_light_curve(): + times = np.array([0.0, 1.0, 2.0]) + fluxes = np.array([0.0, np.inf, 20.0]) + light_curve = LightCurve.new( + times=times, + fluxes=fluxes, + ) + updated_light_curve = remove_infinite_flux_data_points_from_light_curve(light_curve=light_curve) + expected_times = np.array([0.0, 2.0]) + expected_fluxes = np.array([0.0, 20.0]) + assert np.array_equal(updated_light_curve.times, expected_times) + assert np.array_equal(updated_light_curve.fluxes, expected_fluxes)