Skip to content

Commit

Permalink
Correct various ruff fmt issues
Browse files Browse the repository at this point in the history
  • Loading branch information
golmschenk committed Feb 14, 2024
1 parent aceaf58 commit 0b90e70
Show file tree
Hide file tree
Showing 30 changed files with 207 additions and 184 deletions.
4 changes: 2 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Configuration for the pytest tests."""

import matplotlib
import matplotlib as mpl
import pytest

matplotlib.use('Agg') # Use non-interactive backend to prevent loss of focus during test.
mpl.use('Agg') # Use non-interactive backend to prevent loss of focus during test.


def pytest_addoption(parser):
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -107,5 +107,8 @@ exclude_lines = [
extend-exclude = ["examples"]

[tool.ruff.lint]
ignore = ["RET504"]
ignore = [
"RET504", # Subjective but, naming the returned value often seems to help readability.
"SIM108", # Subjective but, ternary operators are often too confusing.
]
isort.known-first-party = ["qusi", "ramjet"]
2 changes: 1 addition & 1 deletion src/qusi/hadryss_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def determine_block_pooling_sizes_and_dense_size(self) -> (list[int], int):
number_of_pooling_blocks = 9
pooling_sizes = [1] * number_of_pooling_blocks
while True:
for pooling_size_index, pooling_size in enumerate(pooling_sizes):
for pooling_size_index, _pooling_size in enumerate(pooling_sizes):
current_size = self.input_length
for current_pooling_size in pooling_sizes:
current_size -= 2
Expand Down
6 changes: 4 additions & 2 deletions src/qusi/light_curve_collection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from pathlib import Path
from random import Random
from typing import Callable
from typing import TYPE_CHECKING, Callable

import numpy as np
import numpy.typing as npt
Expand All @@ -14,6 +13,9 @@
from qusi.light_curve import LightCurve
from qusi.light_curve_observation import LightCurveObservation

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator


class LightCurveCollectionBase(ABC):
@abstractmethod
Expand Down
36 changes: 21 additions & 15 deletions src/qusi/light_curve_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
import re
import shutil
import socket
from collections.abc import Iterable, Iterator
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, Callable, TypeVar
from typing import TYPE_CHECKING, Any, Callable, TypeVar

import numpy as np
import numpy.typing as npt
Expand All @@ -22,7 +21,6 @@
from typing_extensions import Self

from qusi.light_curve import LightCurve, randomly_roll_light_curve, remove_nan_flux_data_points_from_light_curve
from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_observation import (
LightCurveObservation,
randomly_roll_light_curve_observation,
Expand All @@ -33,6 +31,11 @@
pair_array_to_tensor,
)

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator

from qusi.light_curve_collection import LabeledLightCurveCollection


class LightCurveDataset(IterableDataset):
"""
Expand All @@ -49,8 +52,11 @@ def __init__(self,
self.injectee_light_curve_collections: list[LabeledLightCurveCollection] = injectee_light_curve_collections
self.injectable_light_curve_collections: list[LabeledLightCurveCollection] = injectable_light_curve_collections
if len(self.standard_light_curve_collections) == 0 and len(self.injectee_light_curve_collections) == 0:
raise ValueError('Either the standard or injectee light curve collection lists must not be empty. '
'Both were empty.')
error_message = (
'Either the standard or injectee light curve collection lists must not be empty. '
'Both were empty.'
)
raise ValueError(error_message)
self.post_injection_transform: Callable[[Any], Any] = post_injection_transform

def __iter__(self):
Expand Down Expand Up @@ -99,8 +105,11 @@ def new(cls,
post_injection_transform: Callable[[Any], Any] | None = None,
) -> Self:
if standard_light_curve_collections is None and injectee_light_curve_collections is None:
raise ValueError('Either the standard or injectee light curve collection lists must be specified. '
'Both were `None`.')
error_message = (
'Either the standard or injectee light curve collection lists must be specified. '
'Both were `None`.'
)
raise ValueError(error_message)
if standard_light_curve_collections is None:
standard_light_curve_collections = []
if injectee_light_curve_collections is None:
Expand Down Expand Up @@ -138,10 +147,7 @@ def is_injected_dataset(dataset: LightCurveDataset):


def contains_injected_dataset(datasets: list[LightCurveDataset]):
for dataset in datasets:
if is_injected_dataset(dataset):
return True
return False
return any(is_injected_dataset(dataset) for dataset in datasets)


def interleave_infinite_iterators(*infinite_iterators: Iterator):
Expand All @@ -156,8 +162,7 @@ def interleave_infinite_iterators(*infinite_iterators: Iterator):
def loop_iter_function(iter_function: Callable[[], Iterable[T]]) -> Iterator[T]:
while True:
iterator = iter_function()
for element in iterator:
yield element
yield from iterator


class ObservationType(Enum):
Expand Down Expand Up @@ -197,8 +202,7 @@ def new(cls, *datasets: IterableDataset):

def __iter__(self):
for dataset in self.datasets:
for element in dataset:
yield element
yield from dataset


class LimitedIterableDataset(IterableDataset):
Expand Down Expand Up @@ -229,6 +233,7 @@ def default_light_curve_observation_post_injection_transform(x: LightCurveObserv
x = (normalize_tensor_by_modified_z_score(x[0]), x[1])
return x


def default_light_curve_post_injection_transform(x: LightCurve, length: int) -> (Tensor):
x = remove_nan_flux_data_points_from_light_curve(x)
x = randomly_roll_light_curve(x)
Expand All @@ -238,6 +243,7 @@ def default_light_curve_post_injection_transform(x: LightCurve, length: int) ->
x = normalize_tensor_by_modified_z_score(x)
return x


def normalize_tensor_by_modified_z_score(tensor: Tensor) -> Tensor:
median = torch.median(tensor)
deviation_from_median = tensor - median
Expand Down
2 changes: 2 additions & 0 deletions src/qusi/train_logging_configuration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

Expand Down
2 changes: 1 addition & 1 deletion src/qusi/train_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def train_session(train_datasets: list[LightCurveDataset],
for metric_function in metric_functions:
metric_functions_on_device.append(metric_function.to(device, non_blocking=True))
metric_functions = metric_functions_on_device
for cycle_index in range(hyperparameter_configuration.cycles):
for _cycle_index in range(hyperparameter_configuration.cycles):
train_phase(dataloader=train_dataloader, model=model, loss_function=loss_function,
metric_functions=metric_functions, optimizer=optimizer,
steps=hyperparameter_configuration.train_steps_per_cycle, device=device)
Expand Down
2 changes: 1 addition & 1 deletion src/qusi/wandb_liaison.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

import wandb as wandb
import wandb


def wandb_log(name: str, value: Any, process_rank: int):
Expand Down
10 changes: 6 additions & 4 deletions src/ramjet/analysis/light_curve_folding_vizualizer/viewer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

from typing import Optional
from typing import TYPE_CHECKING

import numpy as np
from astropy import units
from bokeh.document import Document
from bokeh.events import Tap
from bokeh.palettes import Turbo256
from lightkurve.periodogram import LombScarglePeriodogram, Periodogram
Expand All @@ -17,7 +16,10 @@
from bokeh.models import ColumnDataSource, Div, LinearAxis, LinearColorMapper, Range1d, Span, Spinner, TapTool
from bokeh.plotting import figure as Figure

from ramjet.photometric_database.light_curve import LightCurve
if TYPE_CHECKING:
from bokeh.document import Document

from ramjet.photometric_database.light_curve import LightCurve


class FoldedLightCurveColumnName(StrEnum):
Expand All @@ -32,7 +34,7 @@ class PeriodogramColumnName(StrEnum):


class Viewer:
def __init__(self, bokeh_document: Document, light_curve: LightCurve, title: Optional[str] = None):
def __init__(self, bokeh_document: Document, light_curve: LightCurve, title: str | None = None):
self.bokeh_document: Document = bokeh_document
tool_tips = [
("Time", f"@{FoldedLightCurveColumnName.TIME}{{0.0000000}}"),
Expand Down
6 changes: 3 additions & 3 deletions src/ramjet/analysis/light_curve_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from pathlib import Path
from typing import Union
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -15,9 +15,9 @@


def plot_light_curve(times: np.ndarray, fluxes: np.ndarray, labels: np.ndarray = None, predictions: np.ndarray = None,
title: str = None, x_label: str = 'Days', y_label: str = 'Flux',
title: Optional[str] = None, x_label: str = 'Days', y_label: str = 'Flux',
x_limits: (float, float) = (None, None), y_limits: (float, float) = (None, None),
save_path: Union[Path, str] = None, exclude_flux_outliers: bool = False,
save_path: Optional[Union[Path, str]] = None, exclude_flux_outliers: bool = False,
base_data_point_size: float = 3):
"""
Plots a light curve with a consistent styling. If true labels and/or predictions are included, these will
Expand Down
3 changes: 2 additions & 1 deletion src/ramjet/analysis/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def get_latest_log_directory(logs_directory: Union[Path, str]) -> Path:
latest_log_directory = log_directory
latest_log_datetime = log_datetime
if latest_log_directory is None:
raise FileNotFoundError(f'No logs with datetime names found in {logs_directory}')
error_message = f'No logs with datetime names found in {logs_directory}'
raise FileNotFoundError(error_message)
else:
return latest_log_directory
4 changes: 2 additions & 2 deletions src/ramjet/analysis/roc_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Code for a class to calculate receiver operating characteristic (ROC) curves.
"""
from pathlib import Path
from typing import Union
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -82,7 +82,7 @@ def accumulate_confusion_matrix_counts(self, label: np.ndarray, prediction: np.n
self.true_negative_counts += true_negatives
self.false_negative_counts += false_negatives

def generate_roc_plot(self, output_path: Union[str, Path] = 'roc_curve.svg', title: str = None):
def generate_roc_plot(self, output_path: Union[str, Path] = 'roc_curve.svg', title: Optional[str] = None):
"""
Generates a ROC curve plot from the confusion matrix totals which have been accumulated.
Expand Down
18 changes: 5 additions & 13 deletions src/ramjet/analysis/transit_vetter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def is_transit_depth_for_target_physical_for_planet(self, target: TessTarget, tr
"""
transiting_body_radius = target.calculate_transiting_body_radius(transit_depth)
planet_radius_threshold = 1.8 * self.radius_of_jupiter__solar_radii
if transiting_body_radius < planet_radius_threshold:
return True
else:
return False
return transiting_body_radius < planet_radius_threshold

@staticmethod
def has_no_nearby_likely_eclipsing_binary_background_targets(target: TessTarget) -> bool:
Expand All @@ -43,10 +40,7 @@ def has_no_nearby_likely_eclipsing_binary_background_targets(target: TessTarget)
(nearby_target_data_frame['TESS Mag'] < target.magnitude + magnitude_difference_threshold) &
(nearby_target_data_frame['Separation (arcsec)'] < nearby_threshold_arcseconds)
]
if problematic_nearby_target_data_frame.shape[0] == 0:
return True
else:
return False
return problematic_nearby_target_data_frame.shape[0] == 0

@staticmethod
def has_nearby_toi_targets(target: TessTarget) -> bool:
Expand All @@ -62,10 +56,7 @@ def has_nearby_toi_targets(target: TessTarget) -> bool:
(pd.notnull(nearby_target_data_frame['TOI'])) &
(nearby_target_data_frame['Separation (arcsec)'] < nearby_threshold_arcseconds)
]
if problematic_nearby_target_data_frame.shape[0] != 0:
return True
else:
return False
return problematic_nearby_target_data_frame.shape[0] != 0

def get_maximum_physical_depth_for_planet_for_target(self, target: TessTarget,
allow_missing_contamination_ratio: bool = False) -> float:
Expand All @@ -82,7 +73,8 @@ def get_maximum_physical_depth_for_planet_for_target(self, target: TessTarget,
if allow_missing_contamination_ratio:
contamination_ratio = 0
else:
raise ValueError(f'Contamination ratio {contamination_ratio} is not a number.')
error_message = f'Contamination ratio {contamination_ratio} is not a number.'
raise ValueError(error_message)
maximum_physical_depth = (maximum_planet_radius ** 2) / (
(target.radius ** 2) * (1 + contamination_ratio))
return maximum_physical_depth
4 changes: 2 additions & 2 deletions src/ramjet/analysis/viewer/light_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def for_columns(cls, time_column_name: str, flux_column_names: list[str], flux_a
display.flux_column_names = flux_column_names
time_axis_label = convert_column_name_to_display_name(time_column_name)
display.initialize_figure(time_axis_label=time_axis_label, flux_axis_label=flux_axis_label)
display.initialize_data_source(column_names=[time_column_name] + flux_column_names)
display.initialize_data_source(column_names=[time_column_name, *flux_column_names])
for flux_column_name, color in zip(display.flux_column_names, light_curve_colors):
display.add_flux_data_source_line_to_figure(time_column_name=time_column_name, flux_column_name=flux_column_name,
color=color)
Expand All @@ -67,7 +67,7 @@ def initialize_data_source(self, column_names: list[str]):
:param column_names: The column names to include in the data source.
"""
self.data_source = ColumnDataSource(data=pd.DataFrame({column_name: [] for column_name in column_names}))
js_reset = CustomJS(args=dict(figure=self.figure), code='figure.reset.emit()')
js_reset = CustomJS(args={'figure': self.figure}, code='figure.reset.emit()')
self.data_source.js_on_change('data', js_reset)

def add_flux_data_source_line_to_figure(self, time_column_name: str, flux_column_name: str, color: Color):
Expand Down
11 changes: 5 additions & 6 deletions src/ramjet/analysis/viewer/preloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
Code to load view entities in the background so they show up quickly when displayed.
"""
import asyncio
import contextlib
import warnings
from asyncio import Task
from collections import deque
from pathlib import Path
from typing import Deque, Union
from typing import Union

import pandas as pd

Expand All @@ -23,8 +24,8 @@ class Preloader:

def __init__(self):
self.current_view_entity: Union[None, ViewEntity] = None
self.next_view_entity_deque: Deque[ViewEntity] = deque(maxlen=self.maximum_preloaded)
self.previous_view_entity_deque: Deque[ViewEntity] = deque(maxlen=self.maximum_preloaded)
self.next_view_entity_deque: deque[ViewEntity] = deque(maxlen=self.maximum_preloaded)
self.previous_view_entity_deque: deque[ViewEntity] = deque(maxlen=self.maximum_preloaded)
self.identifier_data_frame: Union[pd.DataFrame, None] = None
self.running_loading_task: Union[Task, None] = None

Expand Down Expand Up @@ -125,10 +126,8 @@ async def cancel_loading_task(self):
"""
if self.running_loading_task is not None:
self.running_loading_task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await self.running_loading_task
except asyncio.CancelledError:
pass

async def reset_deques(self):
"""
Expand Down
Loading

0 comments on commit 0b90e70

Please sign in to comment.