Skip to content

Commit

Permalink
Merge pull request #72 from golmschenk/infinite_data_points
Browse files Browse the repository at this point in the history
Infinite data points removal
  • Loading branch information
golmschenk authored Aug 27, 2024
2 parents da577e2 + 4ce0da1 commit 10a1275
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/source/tutorials/crafting_standard_datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ In the previous section, we only changed the length of that the uniform lengthen

```python
def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, length: int, randomize: bool = True) -> (Tensor, Tensor):
x = remove_infinite_flux_data_points_from_light_curve_observation(x)
x = remove_nan_flux_data_points_from_light_curve_observation(x)
if randomize:
x = randomly_roll_light_curve_observation(x)
Expand Down
15 changes: 15 additions & 0 deletions src/qusi/internal/light_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,21 @@ def remove_nan_flux_data_points_from_light_curve(light_curve: LightCurve) -> Lig
return light_curve


def remove_infinite_flux_data_points_from_light_curve(light_curve: LightCurve) -> LightCurve:
"""
Removes infinite values from a light curve. If there is an infinite value in either the times or the
fluxes, both corresponding values are removed.
:param light_curve: The light curve.
:return: The light curve with infinite values removed.
"""
light_curve = deepcopy(light_curve)
infinite_flux_indexes = np.isinf(light_curve.fluxes)
light_curve.fluxes = light_curve.fluxes[~infinite_flux_indexes]
light_curve.times = light_curve.times[~infinite_flux_indexes]
return light_curve


def randomly_roll_light_curve(light_curve: LightCurve) -> LightCurve:
"""
Randomly rolls a light curve. That is, a random position in the light curve is chosen, the light curve
Expand Down
5 changes: 4 additions & 1 deletion src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from qusi.internal.light_curve import (
LightCurve,
randomly_roll_light_curve,
remove_nan_flux_data_points_from_light_curve,
remove_nan_flux_data_points_from_light_curve, remove_infinite_flux_data_points_from_light_curve,
)
from qusi.internal.light_curve_observation import (
LightCurveObservation,
randomly_roll_light_curve_observation,
remove_nan_flux_data_points_from_light_curve_observation,
remove_infinite_flux_data_points_from_light_curve_observation,
)
from qusi.internal.light_curve_transforms import (
from_light_curve_observation_to_fluxes_array_and_label_array,
Expand Down Expand Up @@ -357,6 +358,7 @@ def default_light_curve_observation_post_injection_transform(
:param randomize: Whether to have randomization in the transforms.
:return: The transformed light curve observation.
"""
x = remove_infinite_flux_data_points_from_light_curve_observation(x)
x = remove_nan_flux_data_points_from_light_curve_observation(x)
if randomize:
x = randomly_roll_light_curve_observation(x)
Expand All @@ -382,6 +384,7 @@ def default_light_curve_post_injection_transform(
:param randomize: Whether to have randomization in the transforms.
:return: The transformed light curve.
"""
x = remove_infinite_flux_data_points_from_light_curve(x)
x = remove_nan_flux_data_points_from_light_curve(x)
if randomize:
x = randomly_roll_light_curve(x)
Expand Down
21 changes: 20 additions & 1 deletion src/qusi/internal/light_curve_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from typing_extensions import Self

from qusi.internal.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve
from qusi.internal.light_curve import (LightCurve, randomly_roll_light_curve,
remove_nan_flux_data_points_from_light_curve,
remove_infinite_flux_data_points_from_light_curve)


@dataclass
Expand Down Expand Up @@ -48,6 +50,23 @@ def remove_nan_flux_data_points_from_light_curve_observation(
return light_curve_observation


def remove_infinite_flux_data_points_from_light_curve_observation(
light_curve_observation: LightCurveObservation,
) -> LightCurveObservation:
"""
Removes the inf values from a light curve in a light curve observation. If there is an inf in either the times or the
fluxes, both corresponding values are removed.
:param light_curve_observation: The light curve observation.
:return: The light curve observation with inf values removed.
"""
light_curve_observation = deepcopy(light_curve_observation)
light_curve_observation.light_curve = remove_infinite_flux_data_points_from_light_curve(
light_curve_observation.light_curve
)
return light_curve_observation


def randomly_roll_light_curve_observation(light_curve_observation: LightCurveObservation) -> LightCurveObservation:
"""
Randomly rolls a light curve observation. That is, a random position in the light curve is chosen, the light curve
Expand Down
7 changes: 5 additions & 2 deletions src/qusi/transform.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
Data transform related public interface.
"""
from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve
from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve, \
remove_infinite_flux_data_points_from_light_curve
from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \
default_light_curve_observation_post_injection_transform
from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \
randomly_roll_light_curve_observation
randomly_roll_light_curve_observation, remove_infinite_flux_data_points_from_light_curve_observation
from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \
pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score

Expand All @@ -20,4 +21,6 @@
'randomly_roll_light_curve_observation',
'remove_nan_flux_data_points_from_light_curve',
'remove_nan_flux_data_points_from_light_curve_observation',
'remove_infinite_flux_data_points_from_light_curve',
'remove_infinite_flux_data_points_from_light_curve_observation',
]
17 changes: 17 additions & 0 deletions tests/unit_tests/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import numpy as np

from qusi.internal.light_curve import LightCurve, remove_infinite_flux_data_points_from_light_curve


def test_remove_infinite_flux_data_points_from_light_curve():
times = np.array([0.0, 1.0, 2.0])
fluxes = np.array([0.0, np.inf, 20.0])
light_curve = LightCurve.new(
times=times,
fluxes=fluxes,
)
updated_light_curve = remove_infinite_flux_data_points_from_light_curve(light_curve=light_curve)
expected_times = np.array([0.0, 2.0])
expected_fluxes = np.array([0.0, 20.0])
assert np.array_equal(updated_light_curve.times, expected_times)
assert np.array_equal(updated_light_curve.fluxes, expected_fluxes)

0 comments on commit 10a1275

Please sign in to comment.