Skip to content

Commit

Permalink
Correct injection bug where only injectee fluxes were used
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed May 30, 2024
1 parent 4c75d90 commit f80fc40
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@
from qusi.internal.light_curve_collection import LightCurveObservationCollection


class OutOfBoundsInjectionHandlingMethod(Enum):
"""
An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal.
"""

ERROR = "error"
REPEAT_SIGNAL = "repeat_signal"
RANDOM_INJECTION_LOCATION = "random_inject_location"


class LightCurveDataset(IterableDataset):
"""
A dataset of light curves. Includes cases where light curves can be injected into one another.
Expand Down Expand Up @@ -182,6 +192,8 @@ def new(
def inject_light_curve(
injectee_observation: LightCurveObservation,
injectable_observation: LightCurveObservation,
*,
out_of_bounds_injection_handling_method=OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION,
) -> LightCurveObservation:
(
fluxes_with_injected_signal,
Expand All @@ -192,12 +204,12 @@ def inject_light_curve(
light_curve_fluxes=injectee_observation.light_curve.fluxes,
signal_times=injectable_observation.light_curve.times,
signal_magnifications=injectable_observation.light_curve.fluxes,
out_of_bounds_injection_handling_method=OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION,
out_of_bounds_injection_handling_method=out_of_bounds_injection_handling_method,
baseline_flux_estimation_method=BaselineFluxEstimationMethod.MEDIAN,
)
injected_light_curve = LightCurve.new(
times=injectee_observation.light_curve.times,
fluxes=injectee_observation.light_curve.fluxes,
fluxes=fluxes_with_injected_signal,
)
injected_observation = LightCurveObservation.new(
light_curve=injected_light_curve, label=injectable_observation.label
Expand Down Expand Up @@ -337,16 +349,6 @@ def default_light_curve_post_injection_transform(
return x


class OutOfBoundsInjectionHandlingMethod(Enum):
"""
An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal.
"""

ERROR = "error"
REPEAT_SIGNAL = "repeat_signal"
RANDOM_INJECTION_LOCATION = "random_inject_location"


class BaselineFluxEstimationMethod(Enum):
"""
An enum of to designate the type of baseline flux estimation method to use during training.
Expand Down

0 comments on commit f80fc40

Please sign in to comment.