|
32 | 32 | )
|
33 | 33 | from qusi.internal.light_curve_transforms import (
|
34 | 34 | from_light_curve_observation_to_fluxes_array_and_label_array,
|
35 |
| - pair_array_to_tensor, |
| 35 | + pair_array_to_tensor, normalize_tensor_by_modified_z_score, make_uniform_length, |
36 | 36 | )
|
37 | 37 |
|
38 | 38 | if TYPE_CHECKING:
|
@@ -337,56 +337,6 @@ def default_light_curve_post_injection_transform(
|
337 | 337 | return x
|
338 | 338 |
|
339 | 339 |
|
340 |
| -def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor: |
341 |
| - """ |
342 |
| - Normalizes a tensor by a modified z-score. That is, normalizes the values of the tensor based on the median |
343 |
| - absolute deviation. |
344 |
| -
|
345 |
| - :param tensor: The tensor to normalize. |
346 |
| - :return: The normalized tensor. |
347 |
| - """ |
348 |
| - median = torch.median(tensor) |
349 |
| - deviation_from_median = tensor - median |
350 |
| - absolute_deviation_from_median = torch.abs(deviation_from_median) |
351 |
| - median_absolute_deviation_from_median = torch.median(absolute_deviation_from_median) |
352 |
| - if median_absolute_deviation_from_median != 0: |
353 |
| - modified_z_score = ( |
354 |
| - 0.6745 * deviation_from_median / median_absolute_deviation_from_median |
355 |
| - ) |
356 |
| - else: |
357 |
| - modified_z_score = torch.zeros_like(tensor) |
358 |
| - return modified_z_score |
359 |
| - |
360 |
| - |
361 |
| -def make_uniform_length( |
362 |
| - example: np.ndarray, length: int, *, randomize: bool = True |
363 |
| -) -> np.ndarray: |
364 |
| - """Makes the example a specific length, by clipping those too large and repeating those too small.""" |
365 |
| - if len(example.shape) not in [1, 2]: # Only tested for 1D and 2D cases. |
366 |
| - raise ValueError( |
367 |
| - f"Light curve dimensions expected to be in [1, 2], but found {len(example.shape)}" |
368 |
| - ) |
369 |
| - if randomize: |
370 |
| - example = randomly_roll_elements(example) |
371 |
| - if example.shape[0] == length: |
372 |
| - pass |
373 |
| - elif example.shape[0] > length: |
374 |
| - example = example[:length] |
375 |
| - else: |
376 |
| - elements_to_repeat = length - example.shape[0] |
377 |
| - if len(example.shape) == 1: |
378 |
| - example = np.pad(example, (0, elements_to_repeat), mode="wrap") |
379 |
| - else: |
380 |
| - example = np.pad(example, ((0, elements_to_repeat), (0, 0)), mode="wrap") |
381 |
| - return example |
382 |
| - |
383 |
| - |
384 |
| -def randomly_roll_elements(example: np.ndarray) -> np.ndarray: |
385 |
| - """Randomly rolls the elements.""" |
386 |
| - example = np.roll(example, np.random.randint(example.shape[0]), axis=0) |
387 |
| - return example |
388 |
| - |
389 |
| - |
390 | 340 | class OutOfBoundsInjectionHandlingMethod(Enum):
|
391 | 341 | """
|
392 | 342 | An enum of approaches for handling cases where the injectable signal is shorter than the injectee signal.
|
|
0 commit comments