diff --git a/pyproject.toml b/pyproject.toml index 9545a90e..692b52de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,5 +114,8 @@ ignore = [ "S608", "S301", "S311", + "EM102", # The improved source readability is worth the loss in readability of the traceback in my opinion. + "TRY003", # Disabling EM102 makes this rule trigger in areas it shouldn't. + "G004", # The improved source readability is worth the extra string evaluation cost in my opinion. ] isort.known-first-party = ["qusi", "ramjet"] \ No newline at end of file diff --git a/src/ramjet/data_interface/tess_data_interface.py b/src/ramjet/data_interface/tess_data_interface.py index ea8101b3..20601867 100644 --- a/src/ramjet/data_interface/tess_data_interface.py +++ b/src/ramjet/data_interface/tess_data_interface.py @@ -68,7 +68,10 @@ def is_common_mast_connection_error(exception: Exception) -> bool: """ print(f'Retrying on {exception}...', flush=True) # TODO: Rename function, as it includes more than just MAST now. - return (isinstance(exception, (AstroQueryTimeoutError, ConnectionResetError, TimeoutError, astroquery.exceptions.RemoteServiceError, lightkurve.search.SearchError, requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.ReadTimeout))) + return (isinstance(exception, ( + AstroQueryTimeoutError, ConnectionResetError, TimeoutError, astroquery.exceptions.RemoteServiceError, + lightkurve.search.SearchError, requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError, + requests.exceptions.HTTPError, requests.exceptions.ReadTimeout))) class NoDataProductsFoundException(Exception): @@ -314,7 +317,10 @@ def load_fluxes_and_times_from_fits_file(light_curve_path: str | Path, light_curve = load_light_curve_from_fits_file(light_curve_path) fluxes = light_curve[flux_type.value] times = light_curve['TIME'] - assert times.shape == fluxes.shape + if times.shape != fluxes.shape: + error_message = f'Times and 
fluxes arrays must have the same shape, but have shapes ' \ + f'{times.shape} and {fluxes.shape}.' + raise ValueError(error_message) if remove_nans: # noinspection PyUnresolvedReferences nan_indexes = np.union1d(np.argwhere(np.isnan(fluxes)), np.argwhere(np.isnan(times))) @@ -325,21 +331,24 @@ def load_fluxes_and_times_from_fits_file(light_curve_path: str | Path, def load_fluxes_flux_errors_and_times_from_fits_file(light_curve_path: str | Path, flux_type: TessFluxType = TessFluxType.PDCSAP, - remove_nans: bool = True + *, remove_nans: bool = True ) -> (np.ndarray, np.ndarray, np.ndarray): """ Extract the flux and time values from a TESS FITS file. :param light_curve_path: The path to the FITS file. :param flux_type: The flux type to extract from the FITS file. - :param remove_nans: Whether or not to remove nans. + :param remove_nans: Whether to remove nans. :return: The flux and times values from the FITS file. """ light_curve = load_light_curve_from_fits_file(light_curve_path) fluxes = light_curve[flux_type.value] flux_errors = light_curve[flux_type.value + '_ERR'] times = light_curve['TIME'] - assert times.shape == fluxes.shape + if times.shape != fluxes.shape: + error_message = f'Times and fluxes arrays must have the same shape, but have shapes ' \ + f'{times.shape} and {fluxes.shape}.' + raise ValueError(error_message) if remove_nans: # noinspection PyUnresolvedReferences nan_indexes = np.union1d(np.argwhere(np.isnan(fluxes)), np.union1d(np.argwhere(np.isnan(times)), @@ -350,7 +359,7 @@ def load_fluxes_flux_errors_and_times_from_fits_file(light_curve_path: str | Pat return fluxes, flux_errors, times -def plot_light_curve_from_mast(tic_id: int, sector: int | None = None, exclude_flux_outliers: bool = False, +def plot_light_curve_from_mast(tic_id: int, sector: int | None = None, *, exclude_flux_outliers: bool = False, base_data_point_size=3): """ Downloads and plots a light curve from MAST. 
@@ -454,8 +463,7 @@ def get_variable_data_frame_for_coordinates(coordinates, radius='21s') -> pd.Dat variable_table_list = Vizier.query_region(coordinates, radius=radius, catalog='B/gcvs/gcvs_cat') if len(variable_table_list) > 0: return variable_table_list[0].to_pandas() - else: - return pd.DataFrame() + return pd.DataFrame() @retry(retry_on_exception=is_common_mast_connection_error, stop_max_attempt_number=10) @@ -530,8 +538,10 @@ def get_all_tess_spoc_light_curve_observations_chunk(tic_id: int | list[int]) -> def get_spoc_tic_id_list_from_mast() -> list[int]: sector_data_frames: list[pl.DataFrame] = [] for sector_index in itertools.count(1): - response = requests.get(f'https://archive.stsci.edu/hlsps/tess-spoc/target_lists/s{sector_index:04d}.csv') - if response.status_code != 200: + response = requests.get(f'https://archive.stsci.edu/hlsps/tess-spoc/target_lists/s{sector_index:04d}.csv', + timeout=600) + success_code = 200 + if response.status_code != success_code: break csv_string = response.text[1:] # Remove hashtag from header. sector_data_frame = pl.read_csv(StringIO(csv_string)) @@ -606,10 +616,10 @@ def initialize_astroquery(): Catalogs.TIMEOUT = 2000 Catalogs.PAGESIZE = 3000 try: # Temporary fix for astroquery's update of timeout and pagesize locations. 
- Observations._portal_api_connection.TIMEOUT = 2000 - Observations._portal_api_connection.PAGESIZE = 3000 - Catalogs._portal_api_connection.TIMEOUT = 2000 - Catalogs._portal_api_connection.PAGESIZE = 3000 + Observations._portal_api_connection.TIMEOUT = 2000 # noqa SLF001 + Observations._portal_api_connection.PAGESIZE = 3000 # noqa SLF001 + Catalogs._portal_api_connection.TIMEOUT = 2000 # noqa SLF001 + Catalogs._portal_api_connection.PAGESIZE = 3000 # noqa SLF001 except AttributeError: pass diff --git a/src/ramjet/data_interface/tess_eclipsing_binary_metadata_manager.py b/src/ramjet/data_interface/tess_eclipsing_binary_metadata_manager.py index c0443210..679f9b8c 100644 --- a/src/ramjet/data_interface/tess_eclipsing_binary_metadata_manager.py +++ b/src/ramjet/data_interface/tess_eclipsing_binary_metadata_manager.py @@ -1,6 +1,7 @@ """ Code for managing the TESS eclipsing binary metadata. """ +import logging from pathlib import Path import pandas as pd @@ -10,6 +11,8 @@ brian_powell_eclipsing_binary_csv_path = Path('data/tess_eclipsing_binaries/TESS_EB_catalog_23Jun.csv') +logger = logging.getLogger(__name__) + class TessEclipsingBinaryMetadata(MetadatabaseModel): """ @@ -27,7 +30,7 @@ def build_table(): """ Builds the TESS eclipsing binary metadata table. """ - print('Building TESS eclipsing binary metadata table...') + logger.info('Building TESS eclipsing binary metadata table...') eclipsing_binary_data_frame = pd.read_csv(brian_powell_eclipsing_binary_csv_path, usecols=['ID']) row_count = 0 metadatabase.drop_tables([TessEclipsingBinaryMetadata]) @@ -45,7 +48,7 @@ def build_table(): with metadatabase.atomic(): TessEclipsingBinaryMetadata.insert_many(rows).execute() SchemaManager(TessEclipsingBinaryMetadata).create_indexes() - print(f'Table built. {row_count} rows added.') + logger.info(f'Table built. 
{row_count} rows added.') if __name__ == '__main__': diff --git a/src/ramjet/data_interface/tess_ffi_light_curve_metadata_manager.py b/src/ramjet/data_interface/tess_ffi_light_curve_metadata_manager.py index aa3ecdcb..265ee082 100644 --- a/src/ramjet/data_interface/tess_ffi_light_curve_metadata_manager.py +++ b/src/ramjet/data_interface/tess_ffi_light_curve_metadata_manager.py @@ -2,6 +2,7 @@ Code for managing the TESS FFI metadata SQL table. """ import itertools +import logging from pathlib import Path from peewee import CharField, FloatField, IntegerField, SchemaManager @@ -15,6 +16,8 @@ ) from ramjet.photometric_database.tess_ffi_light_curve import TessFfiLightCurve +logger = logging.getLogger(__name__) + class TessFfiLightCurveMetadata(MetadatabaseModel): """ @@ -72,7 +75,7 @@ def populate_sql_database(self): """ Populates the SQL database based on the light curve files. """ - print('Populating the TESS FFI light curve meta data table...', flush=True) + logger.info('Populating the TESS FFI light curve meta data table...') single_sector_path_globs = [] for sector in range(1, 27): single_sector_path_glob = self.light_curve_root_directory_path.glob( @@ -91,10 +94,10 @@ def populate_sql_database(self): if index % 1000 == 0 and index != 0: self.insert_multiple_rows_from_paths_into_database(batch_paths) batch_paths = [] - print(f'{index} rows inserted...', end='\r', flush=True) + logger.info(f'{index} rows inserted...') if len(batch_paths) > 0: self.insert_multiple_rows_from_paths_into_database(batch_paths) - print(f'TESS FFI light curve meta data table populated. {row_count} rows added.', flush=True) + logger.info(f'TESS FFI light curve meta data table populated. 
{row_count} rows added.') def build_table(self): """ diff --git a/src/ramjet/data_interface/tess_target_metadata_manager.py b/src/ramjet/data_interface/tess_target_metadata_manager.py index be4a4fa6..3b57130b 100644 --- a/src/ramjet/data_interface/tess_target_metadata_manager.py +++ b/src/ramjet/data_interface/tess_target_metadata_manager.py @@ -1,6 +1,7 @@ """ Code for managing the metadata of the TESS targets. """ +import logging from pathlib import Path from peewee import IntegerField @@ -12,7 +13,9 @@ metadatabase, metadatabase_uuid, ) -from ramjet.data_interface.tess_data_interface import TessDataInterface +from ramjet.data_interface.tess_data_interface import get_tic_id_and_sector_from_file_path + +logger = logging.getLogger(__name__) class TessTargetMetadata(MetadatabaseModel): @@ -27,8 +30,6 @@ class TessTargetMetadataManger: """ A class for managing the metadata of TESS targets. """ - tess_data_interface = TessDataInterface() - def __init__(self): self.light_curve_root_directory_path = Path('data/tess_two_minute_cadence_light_curves') @@ -56,7 +57,7 @@ def populate_sql_database(self): """ Populates the SQL database based on the light curve files. """ - print('Populating the TESS target light curve metadata table...') + logger.info('Populating the TESS target light curve metadata table...') path_glob = self.light_curve_root_directory_path.glob('**/*.fits') row_count = 0 batch_paths = [] @@ -69,10 +70,10 @@ def populate_sql_database(self): row_count += self.insert_multiple_rows_from_paths_into_database(batch_paths) batch_paths = [] batch_dataset_splits = [] - print(f'{row_count} rows inserted...', end='\r') + logger.info(f'{row_count} rows inserted...') if len(batch_paths) > 0: row_count += self.insert_multiple_rows_from_paths_into_database(batch_paths) - print(f'TESS target metadata table populated. {row_count} rows added.') + logger.info(f'TESS target metadata table populated. 
{row_count} rows added.') def build_table(self): """ diff --git a/src/ramjet/data_interface/tess_toi_data_interface.py b/src/ramjet/data_interface/tess_toi_data_interface.py index 02265e69..102470ce 100644 --- a/src/ramjet/data_interface/tess_toi_data_interface.py +++ b/src/ramjet/data_interface/tess_toi_data_interface.py @@ -1,7 +1,8 @@ -import warnings +from __future__ import annotations + +import logging from enum import Enum from pathlib import Path -from typing import Union import pandas as pd import requests @@ -12,6 +13,7 @@ get_product_list, ) +logger = logging.getLogger(__name__) class ToiColumns(Enum): """ @@ -45,8 +47,8 @@ def __init__(self, data_directory='data/tess_toi'): self.toi_dispositions_path = self.data_directory.joinpath('toi_dispositions.csv') self.ctoi_dispositions_path = self.data_directory.joinpath('ctoi_dispositions.csv') self.light_curves_directory = self.data_directory.joinpath('light_curves') - self.toi_dispositions_: Union[pd.DataFrame, None] = None - self.ctoi_dispositions_: Union[pd.DataFrame, None] = None + self.toi_dispositions_: pd.DataFrame | None = None + self.ctoi_dispositions_: pd.DataFrame | None = None @property def toi_dispositions(self): @@ -60,7 +62,7 @@ def toi_dispositions(self): try: self.update_toi_dispositions_file() except requests.exceptions.ConnectionError: - warnings.warn('Unable to connect to update TOI file. Attempting to use existing file...') + logger.warning('Unable to connect to update TOI file. Attempting to use existing file...') self.toi_dispositions_ = self.load_toi_dispositions_in_project_format() return self.toi_dispositions_ @@ -69,7 +71,7 @@ def update_toi_dispositions_file(self): Downloads the latest TOI dispositions file. 
""" toi_csv_url = 'https://exofop.ipac.caltech.edu/tess/download_toi.php?sort=toi&output=csv' - response = requests.get(toi_csv_url) + response = requests.get(toi_csv_url, timeout=600) with self.toi_dispositions_path.open('wb') as csv_file: csv_file.write(response.content) @@ -85,7 +87,7 @@ def ctoi_dispositions(self): try: self.update_ctoi_dispositions_file() except requests.exceptions.ConnectionError: - warnings.warn('Unable to connect to update TOI file. Attempting to use existing file...') + logger.warning('Unable to connect to update TOI file. Attempting to use existing file...') self.ctoi_dispositions_ = self.load_ctoi_dispositions_in_project_format() return self.ctoi_dispositions_ @@ -94,7 +96,7 @@ def update_ctoi_dispositions_file(self): Downloads the latest CTOI dispositions file. """ ctoi_csv_url = 'https://exofop.ipac.caltech.edu/tess/download_ctoi.php?sort=ctoi&output=csv' - response = requests.get(ctoi_csv_url) + response = requests.get(ctoi_csv_url, timeout=600) with self.ctoi_dispositions_path.open('wb') as csv_file: csv_file.write(response.content) @@ -165,11 +167,11 @@ def print_exofop_toi_and_ctoi_planet_dispositions_for_tic_target(self, tic_id): """ dispositions_data_frame = self.retrieve_exofop_toi_and_ctoi_planet_disposition_for_tic_id(tic_id) if dispositions_data_frame.shape[0] == 0: - print('No known ExoFOP dispositions found.') + logger.info('No known ExoFOP dispositions found.') return # Use context options to not truncate printed data. with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', None): - print(dispositions_data_frame) + logger.info(dispositions_data_frame) def download_exofop_toi_light_curves_to_directory(self, directory: Path): """ @@ -178,13 +180,13 @@ def download_exofop_toi_light_curves_to_directory(self, directory: Path): :param directory: The directory to download the light curves to. Defaults to the data interface directory. 
""" - print("Downloading ExoFOP TOI disposition CSV...") + logger.info("Downloading ExoFOP TOI disposition CSV...") if isinstance(directory, str): directory = Path(directory) tic_ids = self.toi_dispositions[ToiColumns.tic_id.value].unique() - print('Downloading TESS observation list...') + logger.info('Downloading TESS observation list...') single_sector_observations = get_all_two_minute_single_sector_observations(tic_ids) - print("Downloading light curves which are confirmed or suspected planets in TOI dispositions...") + logger.info("Downloading light curves which are confirmed or suspected planets in TOI dispositions...") suspected_planet_dispositions = self.toi_dispositions[ self.toi_dispositions[ToiColumns.disposition.value] != 'FP'] suspected_planet_observations = pd.merge(single_sector_observations, suspected_planet_dispositions, how='inner', @@ -194,7 +196,7 @@ def download_exofop_toi_light_curves_to_directory(self, directory: Path): suspected_planet_data_products['productFilename'].str.endswith('lc.fits')] suspected_planet_download_manifest = download_products( suspected_planet_light_curve_data_products, data_directory=self.data_directory) - print(f'Verifying and moving light curves to {directory}...') + logger.info(f'Verifying and moving light curves to {directory}...') directory.mkdir(parents=True, exist_ok=True) for _row_index, row in suspected_planet_download_manifest.iterrows(): if row['Status'] == 'COMPLETE': diff --git a/src/ramjet/data_interface/tess_transit_metadata_manager.py b/src/ramjet/data_interface/tess_transit_metadata_manager.py index 619bbcad..dbd5273c 100644 --- a/src/ramjet/data_interface/tess_transit_metadata_manager.py +++ b/src/ramjet/data_interface/tess_transit_metadata_manager.py @@ -2,8 +2,8 @@ Code for managing the TESS transit metadata. 
""" import contextlib +import logging import sqlite3 -import warnings from enum import Enum import pandas as pd @@ -13,6 +13,8 @@ from ramjet.data_interface.tess_toi_data_interface import TessToiDataInterface, ToiColumns from ramjet.database.tess_planet_disposition import TessPlanetDisposition +logger = logging.getLogger(__name__) + class Disposition(Enum): """ @@ -41,12 +43,13 @@ class TessTransitMetadataManager: """ A class for managing the TESS transit metadata. """ + @staticmethod def build_table(): """ Builds the TESS transit metadata table. """ - print('Building TESS transit metadata table...') + logger.info('Building TESS transit metadata table...') tess_toi_data_interface = TessToiDataInterface() toi_dispositions = tess_toi_data_interface.toi_dispositions ctoi_dispositions = tess_toi_data_interface.ctoi_dispositions @@ -54,7 +57,7 @@ def build_table(): ctoi_filtered_dispositions = ctoi_dispositions.filter([ToiColumns.tic_id.value, ToiColumns.disposition.value]) all_dispositions = pd.concat([toi_filtered_dispositions, ctoi_filtered_dispositions], ignore_index=True) target_grouped_dispositions = all_dispositions.groupby(ToiColumns.tic_id.value)[ToiColumns.disposition.value - ].apply(set) + ].apply(set) row_count = 0 metadatabase.drop_tables([TessTransitMetadata]) metadatabase.create_tables([TessTransitMetadata]) @@ -68,13 +71,13 @@ def build_table(): elif 'FP' in disposition_set or 'FA' in disposition_set: database_disposition = Disposition.FALSE_POSITIVE.value else: - warnings.warn(f'Dispositions for TIC {tic_id} are {disposition_set}, which does not contain a known' - f'disposition.') + logger.warning(f'Dispositions for TIC {tic_id} are {disposition_set}, which does not contain' + f' a known disposition.') continue row = TessTransitMetadata(tic_id=tic_id, disposition=database_disposition) row.save() row_count += 1 - print(f'Table built. {row_count} rows added.') + logger.info(f'Table built. 
{row_count} rows added.') @staticmethod def add_tic_ids_as_confirmed(tic_ids: list[int]): @@ -94,7 +97,7 @@ def add_tic_ids_as_confirmed(tic_ids: list[int]): transit.disposition = Disposition.CONFIRMED.value transit.save() rows_added += 1 - print(f'{rows_added} rows added.') + logger.info(f'{rows_added} rows added.') if __name__ == '__main__': diff --git a/src/ramjet/data_interface/tess_two_minute_cadence_light_curve_metadata_manager.py b/src/ramjet/data_interface/tess_two_minute_cadence_light_curve_metadata_manager.py index 3a180031..f932b5cc 100644 --- a/src/ramjet/data_interface/tess_two_minute_cadence_light_curve_metadata_manager.py +++ b/src/ramjet/data_interface/tess_two_minute_cadence_light_curve_metadata_manager.py @@ -1,6 +1,7 @@ """ Code for managing the meta data of the two minute cadence TESS light curves. """ +import logging from pathlib import Path from peewee import CharField, IntegerField, SchemaManager @@ -14,6 +15,8 @@ ) from ramjet.data_interface.tess_data_interface import get_tic_id_and_sector_from_file_path +logger = logging.getLogger(__name__) + class TessTwoMinuteCadenceLightCurveMetadata(MetadatabaseModel): """ @@ -62,7 +65,7 @@ def populate_sql_database(self): """ Populates the SQL database based on the light curve files. """ - print('Populating the TESS two minute cadence light curve meta data table...') + logger.info('Populating the TESS two minute cadence light curve meta data table...') path_glob = self.light_curve_root_directory_path.glob('**/*.fits') row_count = 0 batch_paths = [] @@ -73,10 +76,10 @@ def populate_sql_database(self): if index % 1000 == 0 and index != 0: self.insert_multiple_rows_from_paths_into_database(batch_paths) batch_paths = [] - print(f'{index} rows inserted...', end='\r') + logger.info(f'{index} rows inserted...') if len(batch_paths) > 0: self.insert_multiple_rows_from_paths_into_database(batch_paths) - print(f'TESS two minute cadence light curve meta data table populated. 
{row_count} rows added.') + logger.info(f'TESS two minute cadence light curve meta data table populated. {row_count} rows added.') def build_table(self): """ @@ -86,7 +89,7 @@ def build_table(self): TessTwoMinuteCadenceLightCurveMetadata.create_table() SchemaManager(TessTwoMinuteCadenceLightCurveMetadata).drop_indexes() # To allow for fast insert. self.populate_sql_database() - print('Building indexes...') + logger.info('Building indexes...') SchemaManager(TessTwoMinuteCadenceLightCurveMetadata).create_indexes() # Since we dropped them before. diff --git a/src/ramjet/database/tess_planet_disposition.py b/src/ramjet/database/tess_planet_disposition.py index 1b2683d7..2c5c0cc2 100644 --- a/src/ramjet/database/tess_planet_disposition.py +++ b/src/ramjet/database/tess_planet_disposition.py @@ -14,7 +14,7 @@ class Disposition(Enum): """ An enum to represent the possible planet dispositions. """ - PASS = 'Pass' + PASS = 'Pass' # noqa S105 : False positive assuming field is a password field. CONDITIONAL = 'Conditional' AMBIGUOUS = 'Ambiguous' UNLIKELY = 'Unlikely' @@ -35,7 +35,7 @@ class TessPlanetDisposition(BaseModel): """ A database model for the database entity of a TESS planet disposition. """ - id = AutoField() + id = AutoField() # noqa A003 transiter: TessTransiter = ForeignKeyField(TessTransiter) disposition = CharField(choices=Disposition) source = CharField(choices=Source) diff --git a/src/ramjet/database/tess_target.py b/src/ramjet/database/tess_target.py index 3ff0be25..aa090980 100644 --- a/src/ramjet/database/tess_target.py +++ b/src/ramjet/database/tess_target.py @@ -11,5 +11,5 @@ class TessTarget(BaseModel): """ A model for the TESS target database table. 
""" - id = AutoField() + id = AutoField() # noqa A003 tic_id = IntegerField(index=True, unique=True) diff --git a/src/ramjet/database/tess_transiter.py b/src/ramjet/database/tess_transiter.py index 2bb4ddaf..0f7345bd 100644 --- a/src/ramjet/database/tess_transiter.py +++ b/src/ramjet/database/tess_transiter.py @@ -11,7 +11,7 @@ class TessTransiter(BaseModel): """ A database model for the database entity of a TESS transiter. """ - id = AutoField() + id = AutoField() # noqa A003 target: TessTarget = ForeignKeyField(TessTarget) radius__solar_radii = FloatField(null=True) has_known_contamination_ratio = BooleanField(default=True) diff --git a/src/ramjet/logging/wandb_logger.py b/src/ramjet/logging/wandb_logger.py index bc249587..4943cdd3 100644 --- a/src/ramjet/logging/wandb_logger.py +++ b/src/ramjet/logging/wandb_logger.py @@ -7,13 +7,16 @@ import multiprocessing import queue from abc import ABC, abstractmethod +from typing import TYPE_CHECKING import plotly import plotly.graph_objects as go from plotly.subplots import make_subplots import wandb -from ramjet.photometric_database.light_curve import LightCurve + +if TYPE_CHECKING: + from ramjet.photometric_database.light_curve import LightCurve class ExampleRequest: @@ -53,6 +56,7 @@ class WandbLoggableLightCurve(WandbLoggable): """ A wandb loggable light curve. """ + def __init__(self, light_curve_name: str, light_curve: LightCurve): super().__init__() self.light_curve_name: str = light_curve_name @@ -78,6 +82,7 @@ class WandbLoggableInjection(WandbLoggable): """ A wandb loggable containing logging data for injecting a signal into a light curve. """ + def __init__(self): super().__init__() self.injectee_name: str | None = None @@ -133,8 +138,6 @@ class WandbLogger: """ A class to log to wandb. 
""" - loggable_types = [LightCurve] - def __init__(self): manager = multiprocessing.Manager() self.lock = manager.Lock() @@ -162,8 +165,7 @@ def process_py_mapper_example_queues(self, epoch: int) -> None: if isinstance(queue_item, WandbLoggable): queue_item.log(example_queue_name, epoch) else: - msg = f"{queue_item} is not a handled logger type." - raise ValueError(msg) + raise TypeError(f"{queue_item} is not a handled logger type.") except queue.Empty: break @@ -181,7 +183,9 @@ def create_request_queue_for_collection(self, name: str) -> multiprocessing.Queu :param name: The name of the queue. :return: The queue. """ - assert name not in self.request_queues + if name in self.request_queues: + error_message = f'Trying to create queue {name}, but is already exists in the request queues.' + raise ValueError(error_message) manager = multiprocessing.Manager() queue_ = manager.Queue() self.request_queues[name] = queue_ @@ -194,7 +198,9 @@ def create_example_queue_for_collection(self, name: str) -> multiprocessing.Queu :param name: The name of the queue. :return: The queue. """ - assert name not in self.example_queues + if name in self.example_queues: + error_message = f'Trying to create queue {name}, but is already exists in the example queues.' 
+ raise ValueError(error_message) manager = multiprocessing.Manager() queue_ = manager.Queue() self.example_queues[name] = queue_ @@ -210,9 +216,9 @@ def should_produce_example(request_queue: multiprocessing.Queue) -> bool: """ try: request_queue.get(block=False) - return True except queue.Empty: return False + return True @staticmethod def submit_loggable(example_queue: multiprocessing.Queue, loggable: WandbLoggable) -> None: diff --git a/src/ramjet/photometric_database/derived/moa_microlensing_light_curve_collection.py b/src/ramjet/photometric_database/derived/moa_microlensing_light_curve_collection.py index 19e63070..52b7fc89 100644 --- a/src/ramjet/photometric_database/derived/moa_microlensing_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/moa_microlensing_light_curve_collection.py @@ -149,7 +149,7 @@ def get_paths(self) -> Iterable[Path]: return [Path('')] - def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray): + def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray): # noqa ARG002 """ Loads the times and magnifications from a random generated signal. 
diff --git a/src/ramjet/photometric_database/derived/moa_survey_balanced_tag_database.py b/src/ramjet/photometric_database/derived/moa_survey_balanced_tag_database.py index 3e8e40cf..386f585f 100644 --- a/src/ramjet/photometric_database/derived/moa_survey_balanced_tag_database.py +++ b/src/ramjet/photometric_database/derived/moa_survey_balanced_tag_database.py @@ -1,4 +1,4 @@ -from typing import Union +from __future__ import annotations from ramjet.data_interface.moa_data_interface import MoaDataInterface from ramjet.photometric_database.derived.moa_survey_light_curve_collection import MoaSurveyLightCurveCollection @@ -21,7 +21,7 @@ def __init__(self): self.validation_standard_light_curve_collections = self.create_collection_for_each_tag(dataset_splits=[8]) self.inference_light_curve_collections = self.create_collection_for_each_tag(dataset_splits=[9]) - def create_collection_for_each_tag(self, dataset_splits: Union[list[int], None] + def create_collection_for_each_tag(self, dataset_splits: list[int] | None ) -> list[MoaSurveyLightCurveCollection]: """ Creates a light curve collection for each tag in the survey and assigns the appropriate labels. 
diff --git a/src/ramjet/photometric_database/derived/moa_survey_light_curve_collection.py b/src/ramjet/photometric_database/derived/moa_survey_light_curve_collection.py index e1399122..9b60c51b 100644 --- a/src/ramjet/photometric_database/derived/moa_survey_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/moa_survey_light_curve_collection.py @@ -1,11 +1,11 @@ +from __future__ import annotations + import re import shutil import socket -from collections.abc import Iterable from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING -import numpy as np import pandas as pd import scipy.stats from filelock import FileLock @@ -13,6 +13,11 @@ from ramjet.data_interface.moa_data_interface import MoaDataInterface from ramjet.photometric_database.light_curve_collection import LightCurveCollection +if TYPE_CHECKING: + from collections.abc import Iterable + + import numpy as np + class MoaSurveyLightCurveCollection(LightCurveCollection): """ @@ -20,12 +25,12 @@ class MoaSurveyLightCurveCollection(LightCurveCollection): """ moa_data_interface = MoaDataInterface() - def __init__(self, survey_tags: list[str], dataset_splits: Union[list[int], None] = None, - label: Union[float, list[float], np.ndarray, None] = None): + def __init__(self, survey_tags: list[str], dataset_splits: list[int] | None = None, + label: float | list[float] | np.ndarray | None = None): super().__init__() self.label = label self.survey_tags: list[str] = survey_tags - self.dataset_splits: Union[list[int], None] = dataset_splits + self.dataset_splits: list[int] | None = dataset_splits def get_paths(self) -> Iterable[Path]: """ @@ -56,8 +61,7 @@ def move_path_to_nvme(self, path: Path) -> Path: shutil.copy(path, nvme_tmp_path) nvme_tmp_path.rename(nvme_path) return nvme_path - else: - return path + return path def load_times_and_fluxes_from_path(self, path: Path) -> (np.ndarray, np.ndarray): diff --git 
a/src/ramjet/photometric_database/derived/self_lensing_binary_synthetic_signals_light_curve_collection.py b/src/ramjet/photometric_database/derived/self_lensing_binary_synthetic_signals_light_curve_collection.py index aa33917f..ff46b350 100644 --- a/src/ramjet/photometric_database/derived/self_lensing_binary_synthetic_signals_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/self_lensing_binary_synthetic_signals_light_curve_collection.py @@ -1,6 +1,7 @@ """ Code for a light curve collection of Agnieszka Cieplak's synthetic signals. """ +import logging import re import tarfile import urllib.request @@ -12,6 +13,8 @@ from ramjet.photometric_database.light_curve_collection import LightCurveCollection +logger = logging.getLogger(__name__) + class SelfLensingBinarySyntheticSignalsLightCurveCollection(LightCurveCollection): """ @@ -27,7 +30,7 @@ def download_csv_files(self): """ Downloads Agnieszka Cieplak's synthetic signals in their original CSV form. """ - print('Downloading synthetic signal CSV files...') + logger.info('Downloading synthetic signal CSV files...') tar_file_path = self.data_directory.joinpath('synthetic_signals_csv_files.tar') urllib.request.urlretrieve('https://api.onedrive.com/v1.0/shares/s!AjiSFm1N8Bv7ghXushB7JOzABXdv/root/content', str(tar_file_path)) @@ -43,7 +46,7 @@ def convert_csv_files_to_project_format(self): """ Converts Agnieszka Cieplak's synthetic signal CSV files to the project format feather files. 
""" - print('Converting synthetic signals to project format...') + logger.info('Converting synthetic signals to project format...') out_paths = self.data_directory.glob('*.out') synthetic_signal_csv_paths = [path for path in out_paths if re.match(r'lc_\d+\.out', path.name)] for synthetic_signal_csv_path in synthetic_signal_csv_paths: @@ -116,4 +119,3 @@ def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np light_curve_collection.data_directory.mkdir(parents=True, exist_ok=True) light_curve_collection.download_csv_files() light_curve_collection.convert_csv_files_to_project_format() - print('Self lensing binary synthetic signal light curve collection ready.') diff --git a/src/ramjet/photometric_database/derived/siddhant_solanki_heart_beat_synthetic_signals_collection.py b/src/ramjet/photometric_database/derived/siddhant_solanki_heart_beat_synthetic_signals_collection.py index 990b76c1..5144b09d 100644 --- a/src/ramjet/photometric_database/derived/siddhant_solanki_heart_beat_synthetic_signals_collection.py +++ b/src/ramjet/photometric_database/derived/siddhant_solanki_heart_beat_synthetic_signals_collection.py @@ -1,6 +1,6 @@ -import re +from __future__ import annotations -from peewee import Select +import re from ramjet.data_interface.tess_ffi_light_curve_metadata_manager import TessFfiLightCurveMetadata from ramjet.photometric_database.derived.tess_ffi_light_curve_collection import TessFfiLightCurveCollection @@ -10,15 +10,19 @@ from enum import StrEnum except ImportError: from backports.strenum import StrEnum -from collections.abc import Iterable from pathlib import Path -from typing import Union +from typing import TYPE_CHECKING import numpy as np import pandas as pd from ramjet.photometric_database.light_curve_collection import LightCurveCollection +if TYPE_CHECKING: + from collections.abc import Iterable + + from peewee import Select + class ColumnName(StrEnum): TIME__DAYS = 'time__days' @@ -34,7 +38,7 @@ def __init__(self): def 
get_paths(self) -> Iterable[Path]: all_synthetic_signal_paths = self.data_directory.glob('*.txt') heart_beat_synthetic_signals = [path for path in all_synthetic_signal_paths - if re.match(r'generated_lc_\d+.txt', path.name) is not None] + if re.match(r'generated_lc_\d+.txt', path.name) is not None] return heart_beat_synthetic_signals def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np.ndarray): @@ -44,7 +48,10 @@ def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np magnifications = synthetic_signal_data_frame[ColumnName.MAGNIFICATION].values step_size__days = 0.0069444444 times = np.arange(0, magnifications.shape[0] * step_size__days, step_size__days) - assert times.shape[0] == magnifications.shape[0] + if times.shape != magnifications.shape: + error_message = f'Times and magnifications arrays must have the same shape, but have shapes ' \ + f'{times.shape} and {magnifications.shape}.' + raise ValueError(error_message) return times, magnifications @@ -67,18 +74,23 @@ def load_times_and_magnifications_from_path(self, path: Path) -> (np.ndarray, np magnifications = synthetic_signal_data_frame[ColumnName.MAGNIFICATION].values step_size__days = 0.0069444444 times = np.arange(0, magnifications.shape[0] * step_size__days, step_size__days) - assert times.shape[0] == magnifications.shape[0] + if times.shape != magnifications.shape: + error_message = f'Times and magnifications arrays must have the same shape, but have shapes ' \ + f'{times.shape} and {magnifications.shape}.' + raise ValueError(error_message) return times, magnifications + class TessFfiHeartBeatHardNegativeLightcurveCollection(TessFfiLightCurveCollection): """ A class representing the collection of TESS two minute cadence lightcurves containing eclipsing binaries. 
""" - def __init__(self, dataset_splits: Union[list[int], None] = None, - magnitude_range: (Union[float, None], Union[float, None]) = (None, None)): + + def __init__(self, dataset_splits: list[int] | None = None, + magnitude_range: (float | None, float | None) = (None, None)): super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range) self.label = 0 - self.hard_negative_ids = list(pd.read_csv('data/heart_beat_hard_negatives.csv')['tic_id'].values) + self.hard_negative_ids = pd.read_csv('data/heart_beat_hard_negatives.csv')['tic_id'].values.tolist() def get_sql_query(self) -> Select: """ diff --git a/src/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_light_curve_collection.py b/src/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_light_curve_collection.py index 03c29ff6..5169ba69 100644 --- a/src/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_light_curve_collection.py @@ -1,9 +1,9 @@ """ Code representing the collection of TESS two minute cadence light curves containing eclipsing binaries. """ -from typing import Union +from __future__ import annotations -from peewee import Select +from typing import TYPE_CHECKING from ramjet.data_interface.tess_eclipsing_binary_metadata_manager import TessEclipsingBinaryMetadata from ramjet.data_interface.tess_ffi_light_curve_metadata_manager import TessFfiLightCurveMetadata @@ -11,13 +11,16 @@ from ramjet.data_interface.tess_transit_metadata_manager import TessTransitMetadata from ramjet.photometric_database.derived.tess_ffi_light_curve_collection import TessFfiLightCurveCollection +if TYPE_CHECKING: + from peewee import Select + class TessFfiEclipsingBinaryLightCurveCollection(TessFfiLightCurveCollection): """ A class representing the collection of TESS two minute cadence light curves containing eclipsing binaries. 
""" - def __init__(self, dataset_splits: Union[list[int], None] = None, - magnitude_range: (Union[float, None], Union[float, None]) = (None, None)): + def __init__(self, dataset_splits: list[int] | None = None, + magnitude_range: (float | None, float | None) = (None, None)): super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range) self.label = 1 @@ -38,8 +41,8 @@ class TessFfiAntiEclipsingBinaryForTransitLightCurveCollection(TessFfiLightCurve A class representing the collection of TESS two minute cadence light curves flagged as eclipsing binaries which are not a suspected transit. """ - def __init__(self, dataset_splits: Union[list[int], None] = None, - magnitude_range: (Union[float, None], Union[float, None]) = (None, None)): + def __init__(self, dataset_splits: list[int] | None = None, + magnitude_range: (float | None, float | None) = (None, None)): super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range) self.label = 0 diff --git a/src/ramjet/photometric_database/derived/tess_ffi_light_curve_collection.py b/src/ramjet/photometric_database/derived/tess_ffi_light_curve_collection.py index 3c3b14c2..64bf39c7 100644 --- a/src/ramjet/photometric_database/derived/tess_ffi_light_curve_collection.py +++ b/src/ramjet/photometric_database/derived/tess_ffi_light_curve_collection.py @@ -1,13 +1,11 @@ """ Code for a light curve collection of the TESS FFI data, as produced by Brian Powell. 
""" -from pathlib import Path -from typing import Union +from __future__ import annotations -import numpy as np -from peewee import Select +from pathlib import Path +from typing import TYPE_CHECKING -from ramjet.data_interface.metadatabase import MetadatabaseModel from ramjet.data_interface.tess_ffi_light_curve_metadata_manager import ( TessFfiLightCurveMetadata, TessFfiLightCurveMetadataManager, @@ -15,6 +13,12 @@ from ramjet.photometric_database.sql_metadata_light_curve_collection import SqlMetadataLightCurveCollection from ramjet.photometric_database.tess_ffi_light_curve import TessFfiLightCurve +if TYPE_CHECKING: + import numpy as np + from peewee import Select + + from ramjet.data_interface.metadatabase import MetadatabaseModel + class TessFfiLightCurveCollection(SqlMetadataLightCurveCollection): """ @@ -22,13 +26,13 @@ class TessFfiLightCurveCollection(SqlMetadataLightCurveCollection): """ tess_ffi_light_curve_metadata_manger = TessFfiLightCurveMetadataManager() - def __init__(self, dataset_splits: Union[list[int], None] = None, - magnitude_range: (Union[float, None], Union[float, None]) = (None, None)): + def __init__(self, dataset_splits: list[int] | None = None, + magnitude_range: (float | None, float | None) = (None, None)): super().__init__() self.data_directory: Path = Path('data/tess_ffi_light_curves') self.label = 0 - self.dataset_splits: Union[list[int], None] = dataset_splits - self.magnitude_range: (Union[float, None], Union[float, None]) = magnitude_range + self.dataset_splits: list[int] | None = dataset_splits + self.magnitude_range: (float | None, float | None) = magnitude_range def get_sql_query(self) -> Select: """ diff --git a/src/ramjet/photometric_database/derived/tess_ffi_transit_light_curve_collections.py b/src/ramjet/photometric_database/derived/tess_ffi_transit_light_curve_collections.py index 8e34471b..34c8daf4 100644 --- a/src/ramjet/photometric_database/derived/tess_ffi_transit_light_curve_collections.py +++ 
def __init__(self, dataset_splits: list[int] | None = None,
             magnitude_range: tuple[float | None, float | None] = (None, None)):
    """
    :param dataset_splits: The dataset splits to include, forwarded to the parent collection.
    :param magnitude_range: The (minimum, maximum) magnitude range, forwarded to the parent collection.
    """
    # NOTE(review): the annotation was previously the tuple literal `(float | None, float | None)`,
    # which is not a valid type-checker annotation; `tuple[...]` is the PEP 484 form.
    super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range)
    # Label for this collection's light curves (positive class).
    self.label = 1
""" - def __init__(self, dataset_splits: Union[list[int], None] = None, - magnitude_range: (Union[float, None], Union[float, None]) = (None, None)): + def __init__(self, dataset_splits: list[int] | None = None, + magnitude_range: (float | None, float | None) = (None, None)): super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range) self.label = 0