Skip to content

Commit

Permalink
Rename LabeledLightCurveCollection to LightCurveObservationCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Apr 29, 2024
1 parent 2235440 commit 3e1c7c7
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,28 @@ Note, `qusi` expects the label functions to take in a `Path` object as input, ev

## Creating a light curve collection

Now we're going to join the various functions we've just defined into `LabeledLightCurveCollection`s. For the case of positive train light curves, this looks like:
Now we're going to join the various functions we've just defined into `LightCurveObservationCollection`s. For the case of positive train light curves, this looks like:

```python
positive_train_light_curve_collection = LabeledLightCurveCollection.new(
positive_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
```

This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LabeledLightCurveCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment.
This defines a collection of labeled light curves where `qusi` knows how to obtain the paths, how to load the times and fluxes of the light curves, and how to load the labels. This `LightCurveObservationCollection.new(...` function takes in the three pieces we just built earlier. Note that you pass in the functions themselves, not the output of the functions. So for the `get_paths_function` parameter, we pass `get_positive_train_paths`, not `get_positive_train_paths()` (notice the difference in parenthesis). `qusi` will call these functions internally. However, the above bit of code is not by itself in `examples/transit_dataset.py` as the rest of the code in this tutorial was. This is because `qusi` doesn't use this collection by itself. It uses it as part of a dataset. We will explain why there's this extra layer in a moment.

## Creating a dataset

Finally, we build the dataset `qusi` uses to train the network. First, we'll take a look and then unpack it:

```python
def get_transit_train_dataset():
positive_train_light_curve_collection = LabeledLightCurveCollection.new(
positive_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_train_light_curve_collection = LabeledLightCurveCollection.new(
negative_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_negative_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
Expand Down
14 changes: 7 additions & 7 deletions examples/transit_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pathlib import Path

from qusi.finite_standard_light_curve_observation_dataset import FiniteStandardLightCurveObservationDataset
from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_collection import LightCurveObservationCollection
from qusi.light_curve_dataset import LightCurveDataset
from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve

Expand Down Expand Up @@ -44,11 +44,11 @@ def negative_label_function(path):


def get_transit_train_dataset():
positive_train_light_curve_collection = LabeledLightCurveCollection.new(
positive_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_train_light_curve_collection = LabeledLightCurveCollection.new(
negative_train_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_negative_train_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
Expand All @@ -59,11 +59,11 @@ def get_transit_train_dataset():


def get_transit_validation_dataset():
positive_validation_light_curve_collection = LabeledLightCurveCollection.new(
positive_validation_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_validation_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_validation_light_curve_collection = LabeledLightCurveCollection.new(
negative_validation_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_negative_validation_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
Expand All @@ -74,11 +74,11 @@ def get_transit_validation_dataset():


def get_transit_finite_test_dataset():
positive_test_light_curve_collection = LabeledLightCurveCollection.new(
positive_test_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_test_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_test_light_curve_collection = LabeledLightCurveCollection.new(
negative_test_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_negative_test_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
Expand Down
6 changes: 3 additions & 3 deletions examples/transit_infinite_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from qusi.hadryss_model import Hadryss
from qusi.device import get_device
from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_collection import LightCurveObservationCollection
from qusi.light_curve_dataset import LightCurveDataset
from ramjet.photometric_database.tess_two_minute_cadence_light_curve import TessMissionLightCurve

Expand All @@ -36,11 +36,11 @@ def negative_label_function(_path: Path) -> int:


def main():
positive_test_light_curve_collection = LabeledLightCurveCollection.new(
positive_test_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_positive_test_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_test_light_curve_collection = LabeledLightCurveCollection.new(
negative_test_light_curve_collection = LightCurveObservationCollection.new(
get_paths_function=get_negative_test_paths,
load_times_and_fluxes_from_path_function=load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
Expand Down
6 changes: 3 additions & 3 deletions src/qusi/finite_standard_light_curve_observation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@
from torch.utils.data import Dataset
from typing_extensions import Self

from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_collection import LightCurveObservationCollection
from qusi.light_curve_dataset import default_light_curve_observation_post_injection_transform


@dataclass
class FiniteStandardLightCurveObservationDataset(Dataset):
standard_light_curve_collections: list[LabeledLightCurveCollection]
standard_light_curve_collections: list[LightCurveObservationCollection]
post_injection_transform: Callable[[Any], Any]
length: int
collection_start_indexes: list[int]

@classmethod
def new(cls, standard_light_curve_collections: list[LabeledLightCurveCollection]) -> Self:
def new(cls, standard_light_curve_collections: list[LightCurveObservationCollection]) -> Self:
length = 0
collection_start_indexes: list[int] = []
for standard_light_curve_collection in standard_light_curve_collections:
Expand Down
4 changes: 2 additions & 2 deletions src/qusi/light_curve_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __getitem__(self, index: int) -> LightCurve:


@dataclass
class LabeledLightCurveCollection(
class LightCurveObservationCollection(
LightCurveObservationCollectionBase, LightCurveObservationIndexableBase
):
"""
Expand Down Expand Up @@ -304,4 +304,4 @@ def constant_label_for_path_before_partial(_path: Path, label: int) -> int:
return label


LightCurveObservationCollection = LabeledLightCurveCollection
LabeledLightCurveCollection = LightCurveObservationCollection
20 changes: 10 additions & 10 deletions src/qusi/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator

from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_collection import LightCurveObservationCollection


class LightCurveDataset(IterableDataset):
Expand All @@ -48,19 +48,19 @@ class LightCurveDataset(IterableDataset):

def __init__(
self,
standard_light_curve_collections: list[LabeledLightCurveCollection],
injectee_light_curve_collections: list[LabeledLightCurveCollection],
injectable_light_curve_collections: list[LabeledLightCurveCollection],
standard_light_curve_collections: list[LightCurveObservationCollection],
injectee_light_curve_collections: list[LightCurveObservationCollection],
injectable_light_curve_collections: list[LightCurveObservationCollection],
post_injection_transform: Callable[[Any], Any],
):
self.standard_light_curve_collections: list[
LabeledLightCurveCollection
LightCurveObservationCollection
] = standard_light_curve_collections
self.injectee_light_curve_collections: list[
LabeledLightCurveCollection
LightCurveObservationCollection
] = injectee_light_curve_collections
self.injectable_light_curve_collections: list[
LabeledLightCurveCollection
LightCurveObservationCollection
] = injectable_light_curve_collections
if (
len(self.standard_light_curve_collections) == 0
Expand Down Expand Up @@ -153,11 +153,11 @@ def __iter__(self):
@classmethod
def new(
cls,
standard_light_curve_collections: list[LabeledLightCurveCollection]
standard_light_curve_collections: list[LightCurveObservationCollection]
| None = None,
injectee_light_curve_collections: list[LabeledLightCurveCollection]
injectee_light_curve_collections: list[LightCurveObservationCollection]
| None = None,
injectable_light_curve_collections: list[LabeledLightCurveCollection]
injectable_light_curve_collections: list[LightCurveObservationCollection]
| None = None,
post_injection_transform: Callable[[Any], Any] | None = None,
) -> Self:
Expand Down
10 changes: 5 additions & 5 deletions src/qusi/toy_light_curve_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset
from qusi.light_curve import LightCurve
from qusi.light_curve_collection import (
LabeledLightCurveCollection,
LightCurveObservationCollection,
create_constant_label_for_path_function, LightCurveCollection,
)
from qusi.light_curve_dataset import LightCurveDataset
Expand Down Expand Up @@ -70,16 +70,16 @@ def toy_sine_wave_light_curve_load_times_and_fluxes(
return light_curve.times, light_curve.fluxes


def get_toy_flat_light_curve_observation_collection() -> LabeledLightCurveCollection:
return LabeledLightCurveCollection.new(
def get_toy_flat_light_curve_observation_collection() -> LightCurveObservationCollection:
return LightCurveObservationCollection.new(
get_paths_function=toy_light_curve_get_paths_function,
load_times_and_fluxes_from_path_function=toy_flat_light_curve_load_times_and_fluxes,
load_label_from_path_function=create_constant_label_for_path_function(0),
)


def get_toy_sine_wave_light_curve_observation_collection() -> LabeledLightCurveCollection:
return LabeledLightCurveCollection.new(
def get_toy_sine_wave_light_curve_observation_collection() -> LightCurveObservationCollection:
return LightCurveObservationCollection.new(
get_paths_function=toy_light_curve_get_paths_function,
load_times_and_fluxes_from_path_function=toy_sine_wave_light_curve_load_times_and_fluxes,
load_label_from_path_function=create_constant_label_for_path_function(1),
Expand Down

0 comments on commit 3e1c7c7

Please sign in to comment.