Correct various ruff issues
golmschenk committed Feb 16, 2024
1 parent bd3dc87 commit e83fccc
Showing 20 changed files with 171 additions and 109 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -114,5 +114,8 @@ ignore = [
"S608",
"S301",
"S311",
"EM102", # The improved source readability is worth the loss in readibility of the traceback in my opinion.
"TRY003", # Disabling EM102 makes this rule trigger in areas it shouldn't.
"G004", # The improved source readability is worth the extra string evaluation cost in my opinion.
]
isort.known-first-party = ["qusi", "ramjet"]
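
A note on what these ignores trade away: EM102 normally forbids passing an f-string directly to an exception, TRY003 discourages long messages defined outside the exception class, and G004 forbids f-strings in logging calls. A minimal sketch of the accepted trade-offs follows; the function and logger are illustrative, not from this repository.

```python
import logging

logger = logging.getLogger(__name__)


def load_sector(sector: int) -> None:
    if sector < 1:
        # With EM102 and TRY003 ignored, the message may be written inline;
        # the cost is that the traceback repeats the formatted string.
        raise ValueError(f'Sector must be positive, got {sector}.')
    # With G004 ignored, this f-string is evaluated even when INFO logging
    # is disabled, a small cost accepted for more readable source.
    logger.info(f'Loading sector {sector}...')
```
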
38 changes: 24 additions & 14 deletions src/ramjet/data_interface/tess_data_interface.py
@@ -68,7 +68,10 @@ def is_common_mast_connection_error(exception: Exception) -> bool:
"""
print(f'Retrying on {exception}...', flush=True)
# TODO: Rename function, as it includes more than just MAST now.
return (isinstance(exception, (AstroQueryTimeoutError, ConnectionResetError, TimeoutError, astroquery.exceptions.RemoteServiceError, lightkurve.search.SearchError, requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError, requests.exceptions.HTTPError, requests.exceptions.ReadTimeout)))
return (isinstance(exception, (
AstroQueryTimeoutError, ConnectionResetError, TimeoutError, astroquery.exceptions.RemoteServiceError,
lightkurve.search.SearchError, requests.exceptions.ChunkedEncodingError, requests.exceptions.ConnectionError,
requests.exceptions.HTTPError, requests.exceptions.ReadTimeout)))


class NoDataProductsFoundException(Exception):
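
The predicate above is designed to plug into a retry decorator; the `@retry(retry_on_exception=..., stop_max_attempt_number=10)` usage later in this file is consistent with the `retrying` package's API. A minimal sketch of the pattern, with a hypothetical fetch function:

```python
import requests
from retrying import retry


def is_transient_error(exception: Exception) -> bool:
    """The predicate receives the raised exception; returning True retries."""
    return isinstance(exception, (ConnectionResetError, TimeoutError,
                                  requests.exceptions.ConnectionError))


@retry(retry_on_exception=is_transient_error, stop_max_attempt_number=10)
def fetch_catalog(url: str) -> bytes:
    # Retried up to 10 times, but only when the predicate returns True.
    response = requests.get(url, timeout=600)
    response.raise_for_status()
    return response.content
```
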
@@ -314,7 +317,10 @@ def load_fluxes_and_times_from_fits_file(light_curve_path: str | Path,
light_curve = load_light_curve_from_fits_file(light_curve_path)
fluxes = light_curve[flux_type.value]
times = light_curve['TIME']
assert times.shape == fluxes.shape
if times.shape != fluxes.shape:
error_message = f'Times and fluxes arrays must have the same shape, but have shapes ' \
f'{times.shape} and {fluxes.shape}.'
raise ValueError(error_message)
if remove_nans:
# noinspection PyUnresolvedReferences
nan_indexes = np.union1d(np.argwhere(np.isnan(fluxes)), np.argwhere(np.isnan(times)))
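
The assert-to-raise change above matters because asserts are stripped when Python runs with `-O`, so a data-integrity check written as an assert can silently vanish. The replacement pattern in isolation, as a standalone sketch:

```python
import numpy as np


def check_matching_shapes(times: np.ndarray, fluxes: np.ndarray) -> None:
    # Unlike `assert times.shape == fluxes.shape`, this check survives
    # `python -O` and reports both offending shapes in the message.
    if times.shape != fluxes.shape:
        error_message = (f'Times and fluxes arrays must have the same shape, '
                         f'but have shapes {times.shape} and {fluxes.shape}.')
        raise ValueError(error_message)
```
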
@@ -325,21 +331,24 @@ def load_fluxes_and_times_from_fits_file(light_curve_path: str | Path,

def load_fluxes_flux_errors_and_times_from_fits_file(light_curve_path: str | Path,
flux_type: TessFluxType = TessFluxType.PDCSAP,
remove_nans: bool = True
*, remove_nans: bool = True
) -> (np.ndarray, np.ndarray, np.ndarray):
"""
    Extract the flux, flux error, and time values from a TESS FITS file.
:param light_curve_path: The path to the FITS file.
:param flux_type: The flux type to extract from the FITS file.
:param remove_nans: Whether or not to remove nans.
:param remove_nans: Whether to remove nans.
    :return: The fluxes, flux errors, and times from the FITS file.
"""
light_curve = load_light_curve_from_fits_file(light_curve_path)
fluxes = light_curve[flux_type.value]
flux_errors = light_curve[flux_type.value + '_ERR']
times = light_curve['TIME']
assert times.shape == fluxes.shape
if times.shape != fluxes.shape:
error_message = f'Times and fluxes arrays must have the same shape, but have shapes ' \
f'{times.shape} and {fluxes.shape}.'
raise ValueError(error_message)
if remove_nans:
# noinspection PyUnresolvedReferences
nan_indexes = np.union1d(np.argwhere(np.isnan(fluxes)), np.union1d(np.argwhere(np.isnan(times)),
@@ -350,7 +359,7 @@ def load_fluxes_flux_errors_and_times_from_fits_file(light_curve_path: str | Pat
return fluxes, flux_errors, times


def plot_light_curve_from_mast(tic_id: int, sector: int | None = None, exclude_flux_outliers: bool = False,
def plot_light_curve_from_mast(tic_id: int, sector: int | None = None, *, exclude_flux_outliers: bool = False,
base_data_point_size=3):
"""
Downloads and plots a light curve from MAST.
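
Both signature changes above insert a bare `*`, which makes the boolean flag keyword-only; ruff's FBT rules flag positional booleans because a bare `True` at a call site says nothing about what it toggles. A sketch of the effect, with a hypothetical function:

```python
def plot_light_curve(tic_id: int, *, exclude_flux_outliers: bool = False) -> None:
    """Everything after the bare `*` must be passed by keyword."""


# plot_light_curve(12345, True)  # TypeError: too many positional arguments
plot_light_curve(12345, exclude_flux_outliers=True)  # self-documenting call
```
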
@@ -454,8 +463,7 @@ def get_variable_data_frame_for_coordinates(coordinates, radius='21s') -> pd.Dat
variable_table_list = Vizier.query_region(coordinates, radius=radius, catalog='B/gcvs/gcvs_cat')
if len(variable_table_list) > 0:
return variable_table_list[0].to_pandas()
else:
return pd.DataFrame()
return pd.DataFrame()


@retry(retry_on_exception=is_common_mast_connection_error, stop_max_attempt_number=10)
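
The `get_variable_data_frame_for_coordinates` change above drops an `else` made redundant by the early return, per ruff's no-else-return rule. The shape of the pattern in isolation:

```python
import pandas as pd


def first_table_or_empty(tables: list[pd.DataFrame]) -> pd.DataFrame:
    if len(tables) > 0:
        return tables[0]
    # No `else` needed: reaching this line already means the condition failed.
    return pd.DataFrame()
```
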
@@ -530,8 +538,10 @@ def get_all_tess_spoc_light_curve_observations_chunk(tic_id: int | list[int]) ->
def get_spoc_tic_id_list_from_mast() -> list[int]:
sector_data_frames: list[pl.DataFrame] = []
for sector_index in itertools.count(1):
response = requests.get(f'https://archive.stsci.edu/hlsps/tess-spoc/target_lists/s{sector_index:04d}.csv')
if response.status_code != 200:
response = requests.get(f'https://archive.stsci.edu/hlsps/tess-spoc/target_lists/s{sector_index:04d}.csv',
timeout=600)
success_code = 200
if response.status_code != success_code:
break
csv_string = response.text[1:] # Remove hashtag from header.
sector_data_frame = pl.read_csv(StringIO(csv_string))
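
Two fixes land in the chunk above: `requests.get` gains an explicit timeout, without which a dead connection can block indefinitely, and the bare `200` is bound to a name before comparison, satisfying the magic-value rule. A condensed sketch with an illustrative URL parameter:

```python
import requests

HTTP_OK = 200  # naming the constant satisfies the magic-value rule (PLR2004)


def sector_list_exists(url: str) -> bool:
    # The timeout (in seconds) bounds how long a stalled connection can hang.
    response = requests.get(url, timeout=600)
    return response.status_code == HTTP_OK
```
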
@@ -606,10 +616,10 @@ def initialize_astroquery():
Catalogs.TIMEOUT = 2000
Catalogs.PAGESIZE = 3000
try: # Temporary fix for astroquery's update of timeout and pagesize locations.
Observations._portal_api_connection.TIMEOUT = 2000
Observations._portal_api_connection.PAGESIZE = 3000
Catalogs._portal_api_connection.TIMEOUT = 2000
Catalogs._portal_api_connection.PAGESIZE = 3000
        Observations._portal_api_connection.TIMEOUT = 2000  # noqa: SLF001
        Observations._portal_api_connection.PAGESIZE = 3000  # noqa: SLF001
        Catalogs._portal_api_connection.TIMEOUT = 2000  # noqa: SLF001
        Catalogs._portal_api_connection.PAGESIZE = 3000  # noqa: SLF001
except AttributeError:
pass
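
SLF001 flags access to another object's underscore-prefixed members, so the suppressions above mark the private-attribute workaround as deliberate without disabling the rule project-wide. The suppression pattern in isolation:

```python
class Connection:
    def __init__(self) -> None:
        self._timeout = 600  # private by convention


connection = Connection()
# Intentional private access, suppressed only on this line:
connection._timeout = 2000  # noqa: SLF001
```
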

@@ -1,6 +1,7 @@
"""
Code for managing the TESS eclipsing binary metadata.
"""
import logging
from pathlib import Path

import pandas as pd
@@ -10,6 +11,8 @@

brian_powell_eclipsing_binary_csv_path = Path('data/tess_eclipsing_binaries/TESS_EB_catalog_23Jun.csv')

logger = logging.getLogger(__name__)


class TessEclipsingBinaryMetadata(MetadatabaseModel):
"""
@@ -27,7 +30,7 @@ def build_table():
"""
Builds the TESS eclipsing binary metadata table.
"""
print('Building TESS eclipsing binary metadata table...')
logger.info('Building TESS eclipsing binary metadata table...')
eclipsing_binary_data_frame = pd.read_csv(brian_powell_eclipsing_binary_csv_path, usecols=['ID'])
row_count = 0
metadatabase.drop_tables([TessEclipsingBinaryMetadata])
@@ -45,7 +48,7 @@ def build_table():
with metadatabase.atomic():
TessEclipsingBinaryMetadata.insert_many(rows).execute()
SchemaManager(TessEclipsingBinaryMetadata).create_indexes()
print(f'Table built. {row_count} rows added.')
logger.info(f'Table built. {row_count} rows added.')


if __name__ == '__main__':
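
One practical note on the print-to-logger migrations throughout this commit: `logger.info` records are dropped by the default WARNING-level root configuration, so output only appears once the entry point configures logging. A sketch of one common arrangement, not taken from this repository:

```python
import logging

logger = logging.getLogger(__name__)


def build_table() -> None:
    logger.info('Building metadata table...')


if __name__ == '__main__':
    # Without this (or equivalent handler setup), the INFO records emitted
    # by the migrated logger.info calls would never reach the console.
    logging.basicConfig(level=logging.INFO)
    build_table()
```
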
@@ -2,6 +2,7 @@
Code for managing the TESS FFI metadata SQL table.
"""
import itertools
import logging
from pathlib import Path

from peewee import CharField, FloatField, IntegerField, SchemaManager
@@ -15,6 +16,8 @@
)
from ramjet.photometric_database.tess_ffi_light_curve import TessFfiLightCurve

logger = logging.getLogger(__name__)


class TessFfiLightCurveMetadata(MetadatabaseModel):
"""
@@ -72,7 +75,7 @@ def populate_sql_database(self):
"""
Populates the SQL database based on the light curve files.
"""
print('Populating the TESS FFI light curve meta data table...', flush=True)
logger.info('Populating the TESS FFI light curve meta data table...')
single_sector_path_globs = []
for sector in range(1, 27):
single_sector_path_glob = self.light_curve_root_directory_path.glob(
@@ -91,10 +94,10 @@ def populate_sql_database(self):
if index % 1000 == 0 and index != 0:
self.insert_multiple_rows_from_paths_into_database(batch_paths)
batch_paths = []
print(f'{index} rows inserted...', end='\r', flush=True)
logger.info(f'{index} rows inserted...')
if len(batch_paths) > 0:
self.insert_multiple_rows_from_paths_into_database(batch_paths)
print(f'TESS FFI light curve meta data table populated. {row_count} rows added.', flush=True)
logger.info(f'TESS FFI light curve meta data table populated. {row_count} rows added.')

def build_table(self):
"""
13 changes: 7 additions & 6 deletions src/ramjet/data_interface/tess_target_metadata_manager.py
@@ -1,6 +1,7 @@
"""
Code for managing the metadata of the TESS targets.
"""
import logging
from pathlib import Path

from peewee import IntegerField
@@ -12,7 +13,9 @@
metadatabase,
metadatabase_uuid,
)
from ramjet.data_interface.tess_data_interface import TessDataInterface
from ramjet.data_interface.tess_data_interface import get_tic_id_and_sector_from_file_path

logger = logging.getLogger(__name__)


class TessTargetMetadata(MetadatabaseModel):
@@ -27,8 +30,6 @@ class TessTargetMetadataManger:
"""
A class for managing the metadata of TESS targets.
"""
tess_data_interface = TessDataInterface()

def __init__(self):
self.light_curve_root_directory_path = Path('data/tess_two_minute_cadence_light_curves')

@@ -56,7 +57,7 @@ def populate_sql_database(self):
"""
Populates the SQL database based on the light curve files.
"""
print('Populating the TESS target light curve metadata table...')
logger.info('Populating the TESS target light curve metadata table...')
path_glob = self.light_curve_root_directory_path.glob('**/*.fits')
row_count = 0
batch_paths = []
@@ -69,10 +70,10 @@
row_count += self.insert_multiple_rows_from_paths_into_database(batch_paths)
batch_paths = []
batch_dataset_splits = []
print(f'{row_count} rows inserted...', end='\r')
logger.info(f'{row_count} rows inserted...')
if len(batch_paths) > 0:
row_count += self.insert_multiple_rows_from_paths_into_database(batch_paths)
print(f'TESS target metadata table populated. {row_count} rows added.')
logger.info(f'TESS target metadata table populated. {row_count} rows added.')

def build_table(self):
"""
30 changes: 16 additions & 14 deletions src/ramjet/data_interface/tess_toi_data_interface.py
@@ -1,7 +1,8 @@
import warnings
from __future__ import annotations

import logging
from enum import Enum
from pathlib import Path
from typing import Union

import pandas as pd
import requests
@@ -12,6 +13,7 @@
get_product_list,
)

logger = logging.getLogger(__name__)

class ToiColumns(Enum):
"""
@@ -45,8 +47,8 @@ def __init__(self, data_directory='data/tess_toi'):
self.toi_dispositions_path = self.data_directory.joinpath('toi_dispositions.csv')
self.ctoi_dispositions_path = self.data_directory.joinpath('ctoi_dispositions.csv')
self.light_curves_directory = self.data_directory.joinpath('light_curves')
self.toi_dispositions_: Union[pd.DataFrame, None] = None
self.ctoi_dispositions_: Union[pd.DataFrame, None] = None
self.toi_dispositions_: pd.DataFrame | None = None
self.ctoi_dispositions_: pd.DataFrame | None = None

@property
def toi_dispositions(self):
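
The `from __future__ import annotations` added at the top of this file is what makes the `Union[pd.DataFrame, None]` to `pd.DataFrame | None` rewrite safe in every annotation position on pre-3.10 Pythons, since annotations are then stored as strings rather than evaluated. A self-contained sketch with a hypothetical class:

```python
from __future__ import annotations  # PEP 604 unions in annotations on 3.7+

import pandas as pd


class DispositionCache:
    """A hypothetical holder mirroring the attribute style above."""

    def __init__(self) -> None:
        # Equivalent to Union[pd.DataFrame, None], with no typing import.
        self.dispositions_: pd.DataFrame | None = None
```
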
@@ -60,7 +62,7 @@ def toi_dispositions(self):
try:
self.update_toi_dispositions_file()
except requests.exceptions.ConnectionError:
warnings.warn('Unable to connect to update TOI file. Attempting to use existing file...')
logger.warning('Unable to connect to update TOI file. Attempting to use existing file...')
self.toi_dispositions_ = self.load_toi_dispositions_in_project_format()
return self.toi_dispositions_

@@ -69,7 +71,7 @@ def update_toi_dispositions_file(self):
Downloads the latest TOI dispositions file.
"""
toi_csv_url = 'https://exofop.ipac.caltech.edu/tess/download_toi.php?sort=toi&output=csv'
response = requests.get(toi_csv_url)
response = requests.get(toi_csv_url, timeout=600)
with self.toi_dispositions_path.open('wb') as csv_file:
csv_file.write(response.content)

@@ -85,7 +87,7 @@ def ctoi_dispositions(self):
try:
self.update_ctoi_dispositions_file()
except requests.exceptions.ConnectionError:
warnings.warn('Unable to connect to update TOI file. Attempting to use existing file...')
logger.warning('Unable to connect to update TOI file. Attempting to use existing file...')
self.ctoi_dispositions_ = self.load_ctoi_dispositions_in_project_format()
return self.ctoi_dispositions_

@@ -94,7 +96,7 @@ def update_ctoi_dispositions_file(self):
Downloads the latest CTOI dispositions file.
"""
ctoi_csv_url = 'https://exofop.ipac.caltech.edu/tess/download_ctoi.php?sort=ctoi&output=csv'
response = requests.get(ctoi_csv_url)
response = requests.get(ctoi_csv_url, timeout=600)
with self.ctoi_dispositions_path.open('wb') as csv_file:
csv_file.write(response.content)

@@ -165,11 +167,11 @@ def print_exofop_toi_and_ctoi_planet_dispositions_for_tic_target(self, tic_id):
"""
dispositions_data_frame = self.retrieve_exofop_toi_and_ctoi_planet_disposition_for_tic_id(tic_id)
if dispositions_data_frame.shape[0] == 0:
print('No known ExoFOP dispositions found.')
logger.info('No known ExoFOP dispositions found.')
return
# Use context options to not truncate printed data.
with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', None):
print(dispositions_data_frame)
logger.info(dispositions_data_frame)

def download_exofop_toi_light_curves_to_directory(self, directory: Path):
"""
@@ -178,13 +180,13 @@ def download_exofop_toi_light_curves_to_directory(self, directory: Path):
:param directory: The directory to download the light curves to. Defaults to the data interface directory.
"""
print("Downloading ExoFOP TOI disposition CSV...")
logger.info("Downloading ExoFOP TOI disposition CSV...")
if isinstance(directory, str):
directory = Path(directory)
tic_ids = self.toi_dispositions[ToiColumns.tic_id.value].unique()
print('Downloading TESS observation list...')
logger.info('Downloading TESS observation list...')
single_sector_observations = get_all_two_minute_single_sector_observations(tic_ids)
print("Downloading light curves which are confirmed or suspected planets in TOI dispositions...")
logger.info("Downloading light curves which are confirmed or suspected planets in TOI dispositions...")
suspected_planet_dispositions = self.toi_dispositions[
self.toi_dispositions[ToiColumns.disposition.value] != 'FP']
suspected_planet_observations = pd.merge(single_sector_observations, suspected_planet_dispositions, how='inner',
Expand All @@ -194,7 +196,7 @@ def download_exofop_toi_light_curves_to_directory(self, directory: Path):
suspected_planet_data_products['productFilename'].str.endswith('lc.fits')]
suspected_planet_download_manifest = download_products(
suspected_planet_light_curve_data_products, data_directory=self.data_directory)
print(f'Verifying and moving light curves to {directory}...')
logger.info(f'Verifying and moving light curves to {directory}...')
directory.mkdir(parents=True, exist_ok=True)
for _row_index, row in suspected_planet_download_manifest.iterrows():
if row['Status'] == 'COMPLETE':
