Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use cached_property and types #1718

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 94 additions & 54 deletions src/pint/observatory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,25 @@
necessary.
"""

from copy import deepcopy
import os
import textwrap
from collections import defaultdict
from collections.abc import Callable
from copy import deepcopy
from io import StringIO
from pathlib import Path
from typing import Optional, Union, List, Dict

import astropy.coordinates
import astropy.time
import astropy.units as u
import numpy as np
from astropy.coordinates import EarthLocation
from loguru import logger as log

from pint.config import runtimefile
from pint.pulsar_mjd import Time
from pint.utils import interesting_lines
from pint.utils import interesting_lines, PosVel

# Include any files that define observatories here. This will start
# with the standard distribution files, then will read any system- or
Expand Down Expand Up @@ -87,7 +90,7 @@ class ClockCorrectionOutOfRange(ClockCorrectionError):
_bipm_clock_versions = {}


def _load_gps_clock():
def _load_gps_clock() -> None:
global _gps_clock
if _gps_clock is None:
log.info("Loading global GPS clock file")
Expand All @@ -97,7 +100,7 @@ def _load_gps_clock():
)


def _load_bipm_clock(bipm_version):
def _load_bipm_clock(bipm_version: str) -> None:
bipm_version = bipm_version.lower()
if bipm_version not in _bipm_clock_versions:
try:
Expand Down Expand Up @@ -136,34 +139,43 @@ class Observatory:
position.
"""

fullname: str
"""Full human-readable name of the observatory."""
include_gps: bool
"""Whether to include GPS clock corrections."""
include_bipm: bool
"""Whether to include BIPM clock corrections."""
bipm_version: str
"""Version of the BIPM clock file to use."""

# This is a dict containing all defined Observatory instances,
# keyed on standard observatory name.
_registry = {}
_registry: Dict[str, "Observatory"] = {}

# This is a dict mapping any defined aliases to the corresponding
# standard name.
_alias_map = {}
_alias_map: Dict[str, str] = {}

def __init__(
self,
name,
fullname=None,
aliases=None,
include_gps=True,
include_bipm=True,
bipm_version=bipm_default,
overwrite=False,
name: str,
fullname: Optional[str] = None,
aliases: Optional[List[str]] = None,
include_gps: bool = True,
include_bipm: bool = True,
bipm_version: str = bipm_default,
overwrite: bool = False,
):
self._name = name.lower()
self._aliases = (
self._name: str = name.lower()
self._aliases: List[str] = (
list(set(map(str.lower, aliases))) if aliases is not None else []
)
if aliases is not None:
Observatory._add_aliases(self, aliases)
self.fullname = fullname if fullname is not None else name
self.include_gps = include_gps
self.include_bipm = include_bipm
self.bipm_version = bipm_version
self.fullname: str = fullname if fullname is not None else name
self.include_gps: bool = include_gps
self.include_bipm: bool = include_bipm
self.bipm_version: str = bipm_version

if name.lower() in Observatory._registry:
if not overwrite:
Expand All @@ -175,16 +187,18 @@ def __init__(
Observatory._register(self, name)

@classmethod
def _register(cls, obs, name):
"""Add an observatory to the registry using the specified name
(which will be converted to lower case). If an existing observatory
def _register(cls, obs: "Observatory", name: str) -> None:
"""Add an observatory to the registry using the specified name (which will be converted to lower case).

If an existing observatory
of the same name exists, it will be replaced with the new one.
The Observatory instance's name attribute will be updated for
consistency."""
consistency.
"""
cls._registry[name.lower()] = obs

@classmethod
def _add_aliases(cls, obs, aliases):
def _add_aliases(cls, obs: "Observatory", aliases: List[str]) -> None:
"""Add aliases for the specified Observatory. Aliases
should be given as a list. If any of the new aliases are already in
use, they will be replaced. Aliases are not checked against the
Expand All @@ -196,14 +210,17 @@ def _add_aliases(cls, obs, aliases):
cls._alias_map[a.lower()] = obs.name

@staticmethod
def gps_correction(t, limits="warn"):
def gps_correction(t: astropy.time.Time, limits: str = "warn") -> u.Quantity:
"""Compute the GPS clock corrections for times t."""
log.info("Applying GPS to UTC clock correction (~few nanoseconds)")
_load_gps_clock()
assert _gps_clock is not None
return _gps_clock.evaluate(t, limits=limits)

@staticmethod
def bipm_correction(t, bipm_version=bipm_default, limits="warn"):
def bipm_correction(
t: astropy.time.Time, bipm_version: str = bipm_default, limits: str = "warn"
) -> u.Quantity:
"""Compute the GPS clock corrections for times t."""
log.info(f"Applying TT(TAI) to TT({bipm_version}) clock correction (~27 us)")
tt2tai = 32.184 * 1e6 * u.us
Expand All @@ -214,7 +231,7 @@ def bipm_correction(t, bipm_version=bipm_default, limits="warn"):
)

@classmethod
def clear_registry(cls):
def clear_registry(cls) -> None:
"""Clear registry for ground-based observatories."""
cls._registry = {}
cls._alias_map = {}
Expand All @@ -229,7 +246,7 @@ def names(cls):
return cls._registry.keys()

@classmethod
def names_and_aliases(cls):
def names_and_aliases(cls) -> Dict[str, List[str]]:
"""List all observatories and their aliases"""
import pint.observatory.topo_obs # noqa
import pint.observatory.special_locations # noqa
Expand All @@ -241,15 +258,24 @@ def names_and_aliases(cls):
# setter methods that update the registries appropriately.

@property
def name(self):
def name(self) -> str:
"""Short name of the observatory.

This is the name used in TOA files and in the observatory registry.
"""
return self._name

@property
def aliases(self):
def aliases(self) -> List[str]:
"""List of aliases for the observatory.

These are short names also used to specify this observatory.
Includes ITOA and TEMPO codes, and any other common names.
"""
return self._aliases

@classmethod
def get(cls, name):
def get(cls, name: str) -> "Observatory":
"""Returns the Observatory instance for the specified name/alias.

If the name has not been defined, an error will be raised. Aside
Expand Down Expand Up @@ -303,9 +329,12 @@ def get(cls, name):
# Any which raise NotImplementedError below must be implemented in
# derived classes.

def earth_location_itrf(self, time=None):
"""Returns observatory geocentric position as an astropy
EarthLocation object. For observatories where this is not
def earth_location_itrf(
self, time: Optional[astropy.time.Time] = None
) -> Union[None, np.ndarray]:
"""Returns observatory geocentric position as an astropy EarthLocation object.

For observatories where this is not
relevant, None can be returned.

The location is in the International Terrestrial Reference Frame (ITRF).
Expand All @@ -319,8 +348,9 @@ def earth_location_itrf(self, time=None):
"""
return None

def get_gcrs(self, t, ephem=None):
"""Return position vector of observatory in GCRS
def get_gcrs(self, t: astropy.time.Time, ephem=None):
"""Return position vector of observatory in GCRS.

t is an astropy.Time or array of astropy.Time objects
ephem is a link to an ephemeris file. Needed for SSB observatory
Returns a 3-vector of Quantities representing the position
Expand All @@ -329,14 +359,17 @@ def get_gcrs(self, t, ephem=None):
raise NotImplementedError

@property
def timescale(self):
"""Returns the timescale that TOAs from this observatory will be in,
once any clock corrections have been applied. This should be a
def timescale(self) -> str:
"""Returns the timescale that TOAs from this observatory will be in, once any clock corrections have been applied.

This should be a
string suitable to be passed directly to the scale argument of
astropy.time.Time()."""
raise NotImplementedError

def clock_corrections(self, t, limits="warn"):
def clock_corrections(
self, t: astropy.time.Time, limits: str = "warn"
) -> u.Quantity:
"""Compute clock corrections for a Time array.

Given an array-valued Time, return the clock corrections
Expand All @@ -356,7 +389,7 @@ def clock_corrections(self, t, limits="warn"):

return corr

def last_clock_correction_mjd(self):
def last_clock_correction_mjd(self) -> float:
"""Return the MJD of the last available clock correction.

Returns ``np.inf`` if no clock corrections are relevant.
Expand All @@ -365,6 +398,7 @@ def last_clock_correction_mjd(self):

if self.include_gps:
_load_gps_clock()
assert _gps_clock is not None
t = min(t, _gps_clock.last_correction_mjd())
if self.include_bipm:
_load_bipm_clock(self.bipm_version)
Expand All @@ -374,7 +408,13 @@ def last_clock_correction_mjd(self):
)
return t

def get_TDBs(self, t, method="default", ephem=None, options=None):
def get_TDBs(
self,
t: astropy.time.Time,
method: Union[str, Callable] = "default",
ephem: Optional[str] = None,
options: Optional[dict] = None,
):
"""This is a high level function for converting TOAs to TDB time scale.

Different method can be applied to obtain the result. Current supported
Expand Down Expand Up @@ -409,13 +449,13 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
t = Time([t])
if t.scale == "tdb":
return t
# Check the method. This pattern is from numpy minimize
meth = "_custom" if callable(method) else method.lower()
if options is None:
options = {}
if meth == "_custom":
if callable(method):
options = dict(options)
return method(t, **options)
else:
meth = method.lower()
if meth == "default":
return self._get_TDB_default(t, ephem)
elif meth == "ephemeris":
Expand All @@ -428,17 +468,17 @@ def get_TDBs(self, t, method="default", ephem=None, options=None):
else:
raise ValueError(f"Unknown method '{method}'.")

def _get_TDB_default(self, t, ephem):
def _get_TDB_default(self, t: astropy.time.Time, ephem):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does ephem not get a type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't add it because I wasn't sure what the type was. String I think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it's a string

return t.tdb

def _get_TDB_ephem(self, t, ephem):
def _get_TDB_ephem(self, t: astropy.time.Time, ephem):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same for this instance.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and later

"""Read the ephem TDB-TT column.

This column is provided by DE4XXt version of ephemeris.
"""
raise NotImplementedError

def posvel(self, t, ephem, group=None):
def posvel(self, t: astropy.time.Time, ephem, group=None) -> PosVel:
"""Return observatory position and velocity for the given times.

Position is relative to solar system barycenter; times are
Expand All @@ -451,7 +491,7 @@ def posvel(self, t, ephem, group=None):


def get_observatory(
name, include_gps=None, include_bipm=None, bipm_version=bipm_default
name: str, include_gps=None, include_bipm=None, bipm_version: str = bipm_default
):
"""Convenience function to get observatory object with options.

Expand Down Expand Up @@ -491,14 +531,14 @@ def get_observatory(
return Observatory.get(name)


def earth_location_distance(loc1, loc2):
def earth_location_distance(loc1: EarthLocation, loc2: EarthLocation) -> u.Quantity:
"""Compute the distance between two EarthLocations."""
return (
sum((u.Quantity(loc1.to_geocentric()) - u.Quantity(loc2.to_geocentric())) ** 2)
) ** 0.5


def compare_t2_observatories_dat(t2dir=None):
def compare_t2_observatories_dat(t2dir: Optional[str] = None) -> Dict[str, List[Dict]]:
"""Read a tempo2 observatories.dat file and compare with PINT

Produces a report including lines that can be added to PINT's
Expand Down Expand Up @@ -589,7 +629,7 @@ def compare_t2_observatories_dat(t2dir=None):
return report


def compare_tempo_obsys_dat(tempodir=None):
def compare_tempo_obsys_dat(tempodir: Optional[str] = None) -> Dict[str, List[Dict]]:
"""Read a tempo obsys.dat file and compare with PINT.

Produces a report including lines that can be added to PINT's
Expand Down Expand Up @@ -629,8 +669,8 @@ def compare_tempo_obsys_dat(tempodir=None):
y = float(line_io.read(15))
z = float(line_io.read(15))
line_io.read(2)
icoord = line_io.read(1).strip()
icoord = int(icoord) if icoord else 0
icoord_str = line_io.read(1).strip()
icoord = int(icoord_str) if icoord_str else 0
line_io.read(2)
obsnam = line_io.read(20).strip().lower()
tempo_code = line_io.read(1)
Expand Down
Loading
Loading