From 1321926a5a8c72839602fdc0b9d8799dadd324be Mon Sep 17 00:00:00 2001 From: golmschenk Date: Fri, 3 May 2024 17:54:13 -0400 Subject: [PATCH] Add inputs of post_injection_transforms --- ...sit_identification_dataset_construction.md | 20 ++++------------ .../tutorials/crafting_standard_datasets.md | 6 ++--- examples/transit_dataset.py | 4 ++-- .../finite_standard_light_curve_dataset.py | 14 +++++++++-- ...tandard_light_curve_observation_dataset.py | 24 +++++++++++++++---- src/qusi/internal/light_curve_dataset.py | 8 +++---- 6 files changed, 46 insertions(+), 30 deletions(-) diff --git a/docs/source/tutorials/basic_transit_identification_dataset_construction.md b/docs/source/tutorials/basic_transit_identification_dataset_construction.md index 1a61e93d..4b9984cf 100644 --- a/docs/source/tutorials/basic_transit_identification_dataset_construction.md +++ b/docs/source/tutorials/basic_transit_identification_dataset_construction.md @@ -56,10 +56,7 @@ Note, `qusi` expects the label functions to take in a `Path` object as input, ev Now we're going to join the various functions we've just defined into `LightCurveObservationCollection`s. For the case of positive train light curves, this looks like: ```python -positive_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) +positive_train_light_curve_collection = LightCurveObservationCollection.new() ``` This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment. @@ -70,17 +67,10 @@ Finally, we build the dataset `qusi` uses to train the network. First, we'll tak ```python def get_transit_train_dataset(): - positive_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_positive_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=positive_label_function) - negative_train_light_curve_collection = LightCurveObservationCollection.new( - get_paths_function=get_negative_train_paths, - load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, - load_label_from_path_function=negative_label_function) - train_light_curve_dataset = LightCurveDataset.new( - standard_light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection]) + positive_train_light_curve_collection = LightCurveObservationCollection.new() + negative_train_light_curve_collection = LightCurveObservationCollection.new() + train_light_curve_dataset = LightCurveDataset.new(light_curve_collections=[positive_train_light_curve_collection, + negative_train_light_curve_collection]) return train_light_curve_dataset ``` diff --git a/docs/source/tutorials/crafting_standard_datasets.md b/docs/source/tutorials/crafting_standard_datasets.md index 38cfcd8e..13578afd 100644 --- a/docs/source/tutorials/crafting_standard_datasets.md +++ b/docs/source/tutorials/crafting_standard_datasets.md @@ -19,9 +19,9 @@ Then, were we specify the construction of our dataset, we'll add an additional i ```python train_light_curve_dataset = LightCurveObservationDataset.new( - standard_light_curve_collections=[positive_train_light_curve_collection, - negative_train_light_curve_collection] - post_injection_transform=partial(default_light_curve_post_injection_transform, length=4000) + light_curve_collections=[positive_train_light_curve_collection, + negative_train_light_curve_collection]) +post_injection_transform = partial(default_light_curve_post_injection_transform, length=4000) ) ``` diff --git a/examples/transit_dataset.py b/examples/transit_dataset.py index 02db37a0..918f15de 100644 --- a/examples/transit_dataset.py +++ b/examples/transit_dataset.py @@ -81,6 +81,6 @@ def get_transit_finite_test_dataset(): load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path, load_label_from_path_function=negative_label_function) test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new( - standard_light_curve_collections=[positive_test_light_curve_collection, - negative_test_light_curve_collection]) + light_curve_collections=[positive_test_light_curve_collection, + negative_test_light_curve_collection]) return test_light_curve_dataset diff --git a/src/qusi/internal/finite_standard_light_curve_dataset.py b/src/qusi/internal/finite_standard_light_curve_dataset.py index c3f7f6cb..6c84e9ec 100644 --- a/src/qusi/internal/finite_standard_light_curve_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from functools import partial from typing import Any, Callable @@ -18,13 +20,21 @@ class FiniteStandardLightCurveDataset(Dataset): collection_start_indexes: list[int] @classmethod - def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: + def new( + cls, + light_curve_collections: list[LightCurveCollection], + *, + post_injection_transform: Callable[[Any], Any] | None = None, + ) -> Self: """ Creates a new `FiniteStandardLightCurveDataset`. :param light_curve_collections: The light curve collections to include in the dataset. + :param post_injection_transform: Transforms to the data to occur after injection. :return: The dataset. """ + if post_injection_transform is None: + post_injection_transform = partial(default_light_curve_post_injection_transform, length=2500) length = 0 collection_start_indexes: list[int] = [] for light_curve_collection in light_curve_collections: @@ -33,7 +43,7 @@ def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self: length += standard_light_curve_collection_length instance = cls( standard_light_curve_collections=light_curve_collections, - post_injection_transform=partial(default_light_curve_post_injection_transform, length=2500), + post_injection_transform=post_injection_transform, length=length, collection_start_indexes=collection_start_indexes, ) diff --git a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py index c64f4b66..21a44a14 100644 --- a/src/qusi/internal/finite_standard_light_curve_observation_dataset.py +++ b/src/qusi/internal/finite_standard_light_curve_observation_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass from functools import partial from typing import Any, Callable @@ -18,16 +20,30 @@ class FiniteStandardLightCurveObservationDataset(Dataset): collection_start_indexes: list[int] @classmethod - def new(cls, standard_light_curve_collections: list[LightCurveObservationCollection]) -> Self: + def new( + cls, + light_curve_collections: list[LightCurveObservationCollection], + *, + post_injection_transform: Callable[[Any], Any] | None = None, + ) -> Self: + """ + Creates a new `FiniteStandardLightCurveObservationDataset`. + + :param light_curve_collections: The light curve observation collections to include in the dataset. + :param post_injection_transform: Transforms to the data to occur after injection. + :return: The dataset. + """ + if post_injection_transform is None: + post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=2500) length = 0 collection_start_indexes: list[int] = [] - for standard_light_curve_collection in standard_light_curve_collections: + for standard_light_curve_collection in light_curve_collections: standard_light_curve_collection_length = len(list(standard_light_curve_collection.path_getter.get_paths())) collection_start_indexes.append(length) length += standard_light_curve_collection_length instance = cls( - standard_light_curve_collections=standard_light_curve_collections, - post_injection_transform=partial(default_light_curve_observation_post_injection_transform, length=2500), + standard_light_curve_collections=light_curve_collections, + post_injection_transform=post_injection_transform, length=length, collection_start_indexes=collection_start_indexes, ) diff --git a/src/qusi/internal/light_curve_dataset.py b/src/qusi/internal/light_curve_dataset.py index b12818e9..0c8112e0 100644 --- a/src/qusi/internal/light_curve_dataset.py +++ b/src/qusi/internal/light_curve_dataset.py @@ -49,6 +49,7 @@ class LightCurveDataset(IterableDataset): def __init__( self, standard_light_curve_collections: list[LightCurveObservationCollection], + *, injectee_light_curve_collections: list[LightCurveObservationCollection], injectable_light_curve_collections: list[LightCurveObservationCollection], post_injection_transform: Callable[[Any], Any], @@ -134,8 +135,8 @@ def __iter__(self): @classmethod def new( cls, - *, standard_light_curve_collections: list[LightCurveObservationCollection] | None = None, + *, injectee_light_curve_collections: list[LightCurveObservationCollection] | None = None, injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None, post_injection_transform: Callable[[Any], Any] | None = None, @@ -276,9 +277,8 @@ def __iter__(self): break -def default_light_curve_observation_post_injection_transform( - x: LightCurveObservation, *, length: int -) -> (Tensor, Tensor): +def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, *, length: int + ) -> (Tensor, Tensor): x = remove_nan_flux_data_points_from_light_curve_observation(x) x = randomly_roll_light_curve_observation(x) x = from_light_curve_observation_to_fluxes_array_and_label_array(x)