Skip to content

Commit

Permalink
Add inputs of post_injection_transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed May 3, 2024
1 parent b3a97d9 commit 1321926
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,7 @@ Note, `qusi` expects the label functions to take in a `Path` object as input, ev
Now we're going to join the various functions we've just defined into `LightCurveObservationCollection`s. For the case of positive train light curves, this looks like:

```python
positive_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
positive_train_light_curve_collection = LightCurveObservationCollection.new()
```

This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parentheses). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment.
Expand All @@ -70,17 +67,10 @@ Finally, we build the dataset `qusi` uses to train the network. First, we'll tak

```python
def get_transit_train_dataset():
positive_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_negative_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
train_light_curve_dataset = LightCurveDataset.new(
standard_light_curve_collections=[positive_train_light_curve_collection,
negative_train_light_curve_collection])
positive_train_light_curve_collection = LightCurveObservationCollection.new()
negative_train_light_curve_collection = LightCurveObservationCollection.new()
train_light_curve_dataset = LightCurveDataset.new(light_curve_collections=[positive_train_light_curve_collection,
negative_train_light_curve_collection])
return train_light_curve_dataset
```

Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/crafting_standard_datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ Then, where we specify the construction of our dataset, we'll add an additional i

```python
train_light_curve_dataset = LightCurveObservationDataset.new(
standard_light_curve_collections=[positive_train_light_curve_collection,
negative_train_light_curve_collection]
post_injection_transform=partial(default_light_curve_post_injection_transform, length=4000)
light_curve_collections=[positive_train_light_curve_collection,
negative_train_light_curve_collection],
post_injection_transform=partial(default_light_curve_post_injection_transform, length=4000)
)
```

Expand Down
4 changes: 2 additions & 2 deletions examples/transit_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ def get_transit_finite_test_dataset():
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new(
standard_light_curve_collections=[positive_test_light_curve_collection,
negative_test_light_curve_collection])
light_curve_collections=[positive_test_light_curve_collection,
negative_test_light_curve_collection])
return test_light_curve_dataset
14 changes: 12 additions & 2 deletions src/qusi/internal/finite_standard_light_curve_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
Expand All @@ -18,13 +20,21 @@ class FiniteStandardLightCurveDataset(Dataset):
collection_start_indexes: list[int]

@classmethod
def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self:
def new(
cls,
light_curve_collections: list[LightCurveCollection],
*,
post_injection_transform: Callable[[Any], Any] | None = None,
) -> Self:
"""
Creates a new `FiniteStandardLightCurveDataset`.
:param light_curve_collections: The light curve collections to include in the dataset.
:param post_injection_transform: Transforms to the data to occur after injection.
:return: The dataset.
"""
if post_injection_transform is None:
post_injection_transform = partial(default_light_curve_post_injection_transform, length=2500)
length = 0
collection_start_indexes: list[int] = []
for light_curve_collection in light_curve_collections:
Expand All @@ -33,7 +43,7 @@ def new(cls, light_curve_collections: list[LightCurveCollection]) -> Self:
length += standard_light_curve_collection_length
instance = cls(
standard_light_curve_collections=light_curve_collections,
post_injection_transform=partial(default_light_curve_post_injection_transform, length=2500),
post_injection_transform=post_injection_transform,
length=length,
collection_start_indexes=collection_start_indexes,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable
Expand All @@ -18,16 +20,30 @@ class FiniteStandardLightCurveObservationDataset(Dataset):
collection_start_indexes: list[int]

@classmethod
def new(cls, standard_light_curve_collections: list[LightCurveObservationCollection]) -> Self:
def new(
cls,
light_curve_collections: list[LightCurveObservationCollection],
*,
post_injection_transform: Callable[[Any], Any] | None = None,
) -> Self:
"""
Creates a new `FiniteStandardLightCurveObservationDataset`.
:param light_curve_collections: The light curve observation collections to include in the dataset.
:param post_injection_transform: Transforms to the data to occur after injection.
:return: The dataset.
"""
if post_injection_transform is None:
post_injection_transform = partial(default_light_curve_observation_post_injection_transform, length=2500)
length = 0
collection_start_indexes: list[int] = []
for standard_light_curve_collection in standard_light_curve_collections:
for standard_light_curve_collection in light_curve_collections:
standard_light_curve_collection_length = len(list(standard_light_curve_collection.path_getter.get_paths()))
collection_start_indexes.append(length)
length += standard_light_curve_collection_length
instance = cls(
standard_light_curve_collections=standard_light_curve_collections,
post_injection_transform=partial(default_light_curve_observation_post_injection_transform, length=2500),
standard_light_curve_collections=light_curve_collections,
post_injection_transform=post_injection_transform,
length=length,
collection_start_indexes=collection_start_indexes,
)
Expand Down
8 changes: 4 additions & 4 deletions src/qusi/internal/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class LightCurveDataset(IterableDataset):
def __init__(
self,
standard_light_curve_collections: list[LightCurveObservationCollection],
*,
injectee_light_curve_collections: list[LightCurveObservationCollection],
injectable_light_curve_collections: list[LightCurveObservationCollection],
post_injection_transform: Callable[[Any], Any],
Expand Down Expand Up @@ -134,8 +135,8 @@ def __iter__(self):
@classmethod
def new(
cls,
*,
standard_light_curve_collections: list[LightCurveObservationCollection] | None = None,
*,
injectee_light_curve_collections: list[LightCurveObservationCollection] | None = None,
injectable_light_curve_collections: list[LightCurveObservationCollection] | None = None,
post_injection_transform: Callable[[Any], Any] | None = None,
Expand Down Expand Up @@ -276,9 +277,8 @@ def __iter__(self):
break


def default_light_curve_observation_post_injection_transform(
x: LightCurveObservation, *, length: int
) -> (Tensor, Tensor):
def default_light_curve_observation_post_injection_transform(x: LightCurveObservation, *, length: int
) -> (Tensor, Tensor):
x = remove_nan_flux_data_points_from_light_curve_observation(x)
x = randomly_roll_light_curve_observation(x)
x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
Expand Down

0 comments on commit 1321926

Please sign in to comment.