Skip to content

Commit 63fcfc5

Browse files
committed
Remove randomization from make_uniform_length and leave that to a separate randomly roll call
1 parent 250f2a6 commit 63fcfc5

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

docs/source/tutorials/crafting_standard_datasets.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv
3939
if randomize:
4040
x = randomly_roll_light_curve_observation(x)
4141
x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
42-
x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length.
42+
x = (make_uniform_length(x[0], length=length), x[1])
4343
x = pair_array_to_tensor(x)
4444
x = (normalize_tensor_by_modified_z_score(x[0]), x[1])
4545
return x

src/qusi/internal/light_curve_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def default_light_curve_observation_post_injection_transform(
306306
if randomize:
307307
x = randomly_roll_light_curve_observation(x)
308308
x = from_light_curve_observation_to_fluxes_array_and_label_array(x)
309-
x = (make_uniform_length(x[0], length=length, randomize=randomize), x[1]) # Make the fluxes a uniform length.
309+
x = (make_uniform_length(x[0], length=length), x[1]) # Make the fluxes a uniform length.
310310
x = pair_array_to_tensor(x)
311311
x = (normalize_tensor_by_modified_z_score(x[0]), x[1])
312312
return x
@@ -331,7 +331,7 @@ def default_light_curve_post_injection_transform(
331331
if randomize:
332332
x = randomly_roll_light_curve(x)
333333
x = x.fluxes
334-
x = make_uniform_length(x, length=length, randomize=randomize)
334+
x = make_uniform_length(x, length=length)
335335
x = torch.tensor(x, dtype=torch.float32)
336336
x = normalize_tensor_by_modified_z_score(x)
337337
return x

src/qusi/internal/light_curve_transforms.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,12 @@ def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
6363
return modified_z_score
6464

6565

66-
def make_uniform_length(
67-
example: np.ndarray, length: int, *, randomize: bool = True
68-
) -> np.ndarray:
66+
def make_uniform_length(example: np.ndarray, length: int) -> np.ndarray:
6967
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
7068
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
7169
raise ValueError(
7270
f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}"
7371
)
74-
if randomize:
75-
example = randomly_roll_elements(example)
7672
if example.shape[0] == length:
7773
pass
7874
elif example.shape[0] > length:

0 commit comments

Comments
 (0)