This repository has been archived by the owner on Aug 6, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Nathaniel Starkman (@nstarman) <nstarkman@protonmail.com>
- Loading branch information
Showing
8 changed files
with
433 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
"""Common code.""" | ||
|
||
|
||
# __all__ = [ | ||
# ] | ||
|
||
|
||
############################################################################## | ||
# IMPORTS | ||
|
||
# BUILT-IN | ||
import typing as T | ||
|
||
# THIRD PARTY | ||
import astropy.coordinates as coord | ||
import astropy.units as u | ||
|
||
############################################################################## | ||
# PARAMETERS | ||
|
||
EllipsisType = type(Ellipsis) | ||
|
||
UnitType = T.TypeVar("Unit", bound=u.UnitBase) | ||
QuantityType = T.TypeVar("Quantity", bound=u.Quantity) | ||
|
||
|
||
FrameType = T.TypeVar("CoordinateFrame", bound=coord.BaseCoordinateFrame) | ||
SkyCoordType = T.TypeVar("SkyCoord", bound=coord.SkyCoord) | ||
CoordinateType = T.Union[FrameType, SkyCoordType] | ||
|
||
FrameLikeType = T.Union[FrameType, SkyCoordType, str] | ||
|
||
############################################################################## | ||
# END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# -*- coding: utf-8 -*- | ||
# see LICENSE.rst | ||
|
||
"""core.""" | ||
|
||
__all__ = [] | ||
|
||
|
||
############################################################################## | ||
# IMPORTS | ||
# flatten structure | ||
|
||
# PROJECT-SPECIFIC | ||
from . import fitter | ||
from .fitter import * # noqa: F401, F403 | ||
|
||
# alls | ||
__all__ += fitter.__all__ | ||
|
||
|
||
############################################################################## | ||
# END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
"""**DOCSTRING**. | ||
Description. | ||
""" | ||
|
||
__all__ = [ | ||
"PotentialBase", | ||
] | ||
|
||
|
||
############################################################################## | ||
# IMPORTS | ||
|
||
# BUILT-IN | ||
import inspect | ||
import typing as T | ||
from abc import ABC, abstractmethod | ||
from types import ModuleType | ||
|
||
# THIRD PARTY | ||
from astropy.utils.introspection import resolve_name | ||
|
||
############################################################################## | ||
# PARAMETERS | ||
|
||
|
||
############################################################################## | ||
# CODE | ||
############################################################################## | ||
|
||
|
||
class PotentialBase(ABC): | ||
"""Sample a Potential. | ||
Raises | ||
------ | ||
TypeError | ||
On class declaration if metaclass argument 'package' is not a string | ||
or module. | ||
""" | ||
|
||
def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None): | ||
super().__init_subclass__() | ||
|
||
if package is not None: | ||
|
||
if isinstance(package, str): | ||
package = resolve_name(package) | ||
elif not isinstance(package, ModuleType): | ||
raise TypeError | ||
|
||
if package in cls._registry: | ||
raise Exception(f"`{package}` sampler already in registry.") | ||
|
||
cls._package = package | ||
|
||
# /def | ||
|
||
@property | ||
@abstractmethod | ||
def _registry(self): | ||
"""The class registry. Need to override.""" | ||
pass | ||
|
||
# /def | ||
|
||
################################################################# | ||
# utils | ||
|
||
@staticmethod | ||
def _infer_package( | ||
obj: T.Any, package: T.Union[ModuleType, str, None] = None | ||
): | ||
|
||
if inspect.ismodule(package): | ||
pass | ||
elif isinstance(package, str): | ||
package = resolve_name(package) | ||
elif package is None: # Need to get package from obj | ||
info = inspect.getmodule(obj) | ||
|
||
if info is None: # fails for c-compiled things | ||
package = obj.__class__.__module__ | ||
else: | ||
package = info.__package__ | ||
|
||
package = resolve_name(package.split(".")[0]) | ||
|
||
else: | ||
raise TypeError("package must be <module> or <str> or None.") | ||
|
||
return package | ||
|
||
# /def | ||
|
||
|
||
# /class | ||
|
||
|
||
# ------------------------------------------------------------------- | ||
|
||
|
||
############################################################################## | ||
# END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
"""Fit a Potential. | ||
Registering a Fitter | ||
******************** | ||
a | ||
""" | ||
|
||
|
||
__all__ = [ | ||
"PotentialFitter", | ||
] | ||
|
||
|
||
############################################################################## | ||
# IMPORTS | ||
|
||
# BUILT-IN | ||
import typing as T | ||
from types import MappingProxyType, ModuleType | ||
|
||
# PROJECT-SPECIFIC | ||
from .core import PotentialBase | ||
from discO.common import CoordinateType | ||
|
||
############################################################################## | ||
# PARAMETERS | ||
|
||
FITTER_REGISTRY = dict() # package : sampler | ||
# _fitter_package_registry = dict() # sampler : package | ||
|
||
|
||
############################################################################## | ||
# CODE | ||
############################################################################## | ||
|
||
|
||
class PotentialFitter(PotentialBase): | ||
"""Fit a Potential. | ||
Parameters | ||
---------- | ||
pot_type | ||
The type of potential with which to fit the data. | ||
Other Parameters | ||
---------------- | ||
package : `~types.ModuleType` or str or None (optional, keyword only) | ||
The package to which the `potential` belongs. | ||
If not provided (None, default) tries to infer from `potential`. | ||
return_specific_class : bool (optional, keyword only) | ||
Whether to return a `PotentialSampler` or package-specific subclass. | ||
This only applies if instantiating a `PotentialSampler`. | ||
Default False. | ||
""" | ||
|
||
_registry = MappingProxyType(FITTER_REGISTRY) | ||
|
||
def __init_subclass__(cls, package: T.Union[str, ModuleType]): | ||
super().__init_subclass__(package=package) | ||
|
||
FITTER_REGISTRY[cls._package] = cls | ||
|
||
# /defs | ||
|
||
def __new__( | ||
cls, | ||
pot_type: T.Any, | ||
*, | ||
package: T.Union[ModuleType, str, None] = None, | ||
return_specific_class: bool = False, | ||
): | ||
self = super().__new__(cls) | ||
|
||
if cls is PotentialFitter: | ||
package = self._infer_package(pot_type, package) | ||
instance = FITTER_REGISTRY[package](pot_type) | ||
|
||
if return_specific_class: | ||
return instance | ||
else: | ||
self._instance = instance | ||
|
||
return self | ||
|
||
# /def | ||
|
||
# def __init__(self, pot_type, **kwargs): | ||
# self._fitter = pot_type | ||
|
||
################################################################# | ||
# Fitting | ||
|
||
def __call__( | ||
self, | ||
c: CoordinateType, | ||
c_err: T.Optional[CoordinateType] = None, | ||
**kwargs, | ||
): | ||
return self._instance(c, c_err=c_err, **kwargs) | ||
|
||
# /def | ||
|
||
def fit( | ||
self, | ||
c: CoordinateType, | ||
c_err: T.Optional[CoordinateType] = None, | ||
**kwargs, | ||
): | ||
# pass to __call__ | ||
return self(c, c_err=c_err, **kwargs) | ||
|
||
# /def | ||
|
||
# # TODO? wrong place for this | ||
# def draw_realization(self, c, c_err=None, **kwargs): | ||
# """Draw a realization given the errors. | ||
|
||
# .. todo:: | ||
|
||
# rename this function | ||
|
||
# better than equal Gaussian errors | ||
|
||
# See Also | ||
# -------- | ||
# :meth:`~discO.core.sampler.draw_realization` | ||
|
||
# """ | ||
|
||
# # for i in range(nrlz): | ||
|
||
# # # FIXME! this is shit | ||
# # rep = c.represent_as(coord.CartesianRepresentation) | ||
# # rep_err = c_err.re | ||
|
||
# # new_c = c.realize_frame(new_rep) | ||
|
||
# # yield self(c, c_err=c_err, **kwarg) | ||
|
||
# # /def | ||
|
||
|
||
# /class | ||
|
||
# ------------------------------------------------------------------- | ||
|
||
|
||
############################################################################## | ||
# END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# -*- coding: utf-8 -*- | ||
# see LICENSE.rst | ||
|
||
"""**DOCSTRING**.""" | ||
|
||
|
||
__all__ = [] | ||
|
||
|
||
############################################################################## | ||
# IMPORTS | ||
|
||
# PROJECT-SPECIFIC | ||
from . import agama | ||
from .agama import * | ||
|
||
__all__ += agama.__all__ | ||
|
||
############################################################################## | ||
# END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# -*- coding: utf-8 -*- | ||
# see LICENSE.rst | ||
|
||
"""AGAMA interface.""" | ||
|
||
__all__ = [] | ||
|
||
|
||
############################################################################## | ||
# IMPORTS | ||
|
||
# PROJECT-SPECIFIC | ||
from . import fitter | ||
from .fitter import * # noqa: F401, F403 | ||
|
||
__all__ += fitter.__all__ | ||
|
||
|
||
############################################################################## | ||
# END |
Oops, something went wrong.