Skip to content

Commit 250f2a6

Browse files
committed
Move transforms to transforms file
1 parent 60bff1d commit 250f2a6

File tree

3 files changed

+50
-53
lines changed

3 files changed

+50
-53
lines changed

src/qusi/internal/light_curve_dataset.py

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
)
3333
from qusi.internal.light_curve_transforms import (
3434
from_light_curve_observation_to_fluxes_array_and_label_array,
35-
pair_array_to_tensor,
35+
pair_array_to_tensor, normalize_tensor_by_modified_z_score, make_uniform_length,
3636
)
3737

3838
if TYPE_CHECKING:
@@ -337,56 +337,6 @@ def default_light_curve_post_injection_transform(
337337
return x
338338

339339

340-
def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
341-
"""
342-
Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median
343-
absolute deviation.
344-
345-
:param tensor: The tensor to normalize.
346-
:return: The normalized tensor.
347-
"""
348-
median = torch.median(tensor)
349-
deviation_from_median = tensor - median
350-
absolute_deviation_from_median = torch.abs(deviation_from_median)
351-
median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median)
352-
if median_absolute_deviation_from_median != 0:
353-
modified_z_score = (
354-
0.6745 * deviation_from_median / median_absolute_deviation_from_median
355-
)
356-
else:
357-
modified_z_score = torch.zeros_like(tensor)
358-
return modified_z_score
359-
360-
361-
def make_uniform_length(
362-
example: np.ndarray, length: int, *, randomize: bool = True
363-
) -> np.ndarray:
364-
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
365-
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
366-
raise ValueError(
367-
f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}"
368-
)
369-
if randomize:
370-
example = randomly_roll_elements(example)
371-
if example.shape[0] == length:
372-
pass
373-
elif example.shape[0] > length:
374-
example = example[:length]
375-
else:
376-
elements_to_repeat = length - example.shape[0]
377-
if len(example.shape) == 1:
378-
example = np.pad(example, (0, elements_to_repeat), mode="wrap")
379-
else:
380-
example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap")
381-
return example
382-
383-
384-
def randomly_roll_elements(example: np.ndarray) -> np.ndarray:
385-
"""Randomly rolls the elements."""
386-
example = np.roll(example, np.random.randint(example.shape[0]), axis=0)
387-
return example
388-
389-
390340
class OutOfBoundsInjectionHandlingMethod(Enum):
391341
"""
392342
An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal.

src/qusi/internal/light_curve_transforms.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import numpy as np
24
import numpy.typing as npt
35
import torch
@@ -38,3 +40,47 @@ def randomly_roll_elements(example: np.ndarray) -> np.ndarray:
3840
"""Randomly rolls the elements."""
3941
example = np.roll(example, np.random.randint(example.shape[0]), axis=0)
4042
return example
43+
44+
45+
def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
46+
"""
47+
Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median
48+
absolute deviation.
49+
50+
:param tensor: The tensor to normalize.
51+
:return: The normalized tensor.
52+
"""
53+
median = torch.median(tensor)
54+
deviation_from_median = tensor - median
55+
absolute_deviation_from_median = torch.abs(deviation_from_median)
56+
median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median)
57+
if median_absolute_deviation_from_median != 0:
58+
modified_z_score = (
59+
0.6745 * deviation_from_median / median_absolute_deviation_from_median
60+
)
61+
else:
62+
modified_z_score = torch.zeros_like(tensor)
63+
return modified_z_score
64+
65+
66+
def make_uniform_length(
67+
example: np.ndarray, length: int, *, randomize: bool = True
68+
) -> np.ndarray:
69+
"""Makes the example a specific length, by clipping those too large and repeating those too small."""
70+
if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases.
71+
raise ValueError(
72+
f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}"
73+
)
74+
if randomize:
75+
example = randomly_roll_elements(example)
76+
if example.shape[0] == length:
77+
pass
78+
elif example.shape[0] > length:
79+
example = example[:length]
80+
else:
81+
elements_to_repeat = length - example.shape[0]
82+
if len(example.shape) == 1:
83+
example = np.pad(example, (0, elements_to_repeat), mode="wrap")
84+
else:
85+
example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap")
86+
return example

src/qusi/transform.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
"""
44
from qusi.internal.light_curve import randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve
55
from qusi.internal.light_curve_dataset import default_light_curve_post_injection_transform, \
6-
default_light_curve_observation_post_injection_transform, make_uniform_length
6+
default_light_curve_observation_post_injection_transform
77
from qusi.internal.light_curve_observation import remove_nan_flux_data_points_from_light_curve_observation, \
88
randomly_roll_light_curve_observation
99
from qusi.internal.light_curve_transforms import from_light_curve_observation_to_fluxes_array_and_label_array, \
10-
pair_array_to_tensor
10+
pair_array_to_tensor, make_uniform_length, normalize_tensor_by_modified_z_score
1111

1212
__all__ = [
1313
'default_light_curve_post_injection_transform',
1414
'default_light_curve_observation_post_injection_transform',
1515
'from_light_curve_observation_to_fluxes_array_and_label_array',
1616
'make_uniform_length',
17+
'normalize_tensor_by_modified_z_score',
1718
'pair_array_to_tensor',
1819
'randomly_roll_light_curve',
1920
'randomly_roll_light_curve_observation',

0 commit comments

Comments
 (0)