Skip to content
This repository has been archived by the owner on Aug 6, 2024. It is now read-only.

Commit

Permalink
fit stuff
Browse files Browse the repository at this point in the history
Signed-off-by: Nathaniel Starkman (@nstarman) <nstarkman@protonmail.com>
  • Loading branch information
nstarman committed Dec 25, 2020
1 parent a540981 commit 46593c4
Show file tree
Hide file tree
Showing 8 changed files with 433 additions and 1 deletion.
7 changes: 6 additions & 1 deletion discO/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from ._astropy_init import __version__ # noqa # isort:skip

# PROJECT-SPECIFIC
from . import data
from . import core, data, extern
from .core import * # noqa: F401, F403

# All
__all__ += core.__all__
__all__ += extern.__all__

##############################################################################
# END
36 changes: 36 additions & 0 deletions discO/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-

"""Common code."""


# __all__ = [
# ]


##############################################################################
# IMPORTS

# BUILT-IN
import typing as T

# THIRD PARTY
import astropy.coordinates as coord
import astropy.units as u

##############################################################################
# PARAMETERS

EllipsisType = type(Ellipsis)

UnitType = T.TypeVar("Unit", bound=u.UnitBase)
QuantityType = T.TypeVar("Quantity", bound=u.Quantity)


FrameType = T.TypeVar("CoordinateFrame", bound=coord.BaseCoordinateFrame)
SkyCoordType = T.TypeVar("SkyCoord", bound=coord.SkyCoord)
CoordinateType = T.Union[FrameType, SkyCoordType]

FrameLikeType = T.Union[FrameType, SkyCoordType, str]

##############################################################################
# END
22 changes: 22 additions & 0 deletions discO/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
# see LICENSE.rst

"""core."""

__all__ = []


##############################################################################
# IMPORTS
# flatten structure

# PROJECT-SPECIFIC
from . import fitter
from .fitter import * # noqa: F401, F403

# alls
__all__ += fitter.__all__


##############################################################################
# END
108 changes: 108 additions & 0 deletions discO/core/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-

"""**DOCSTRING**.
Description.
"""

__all__ = [
"PotentialBase",
]


##############################################################################
# IMPORTS

# BUILT-IN
import inspect
import typing as T
from abc import ABC, abstractmethod
from types import ModuleType

# THIRD PARTY
from astropy.utils.introspection import resolve_name

##############################################################################
# PARAMETERS


##############################################################################
# CODE
##############################################################################


class PotentialBase(ABC):
"""Sample a Potential.
Raises
------
TypeError
On class declaration if metaclass argument 'package' is not a string
or module.
"""

def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None):
super().__init_subclass__()

if package is not None:

if isinstance(package, str):
package = resolve_name(package)
elif not isinstance(package, ModuleType):
raise TypeError

if package in cls._registry:
raise Exception(f"`{package}` sampler already in registry.")

cls._package = package

# /def

@property
@abstractmethod
def _registry(self):
"""The class registry. Need to override."""
pass

# /def

#################################################################
# utils

@staticmethod
def _infer_package(
obj: T.Any, package: T.Union[ModuleType, str, None] = None
):

if inspect.ismodule(package):
pass
elif isinstance(package, str):
package = resolve_name(package)
elif package is None: # Need to get package from obj
info = inspect.getmodule(obj)

if info is None: # fails for c-compiled things
package = obj.__class__.__module__
else:
package = info.__package__

package = resolve_name(package.split(".")[0])

else:
raise TypeError("package must be <module> or <str> or None.")

return package

# /def


# /class


# -------------------------------------------------------------------


##############################################################################
# END
153 changes: 153 additions & 0 deletions discO/core/fitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-

"""Fit a Potential.
Registering a Fitter
********************
a
"""


__all__ = [
"PotentialFitter",
]


##############################################################################
# IMPORTS

# BUILT-IN
import typing as T
from types import MappingProxyType, ModuleType

# PROJECT-SPECIFIC
from .core import PotentialBase
from discO.common import CoordinateType

##############################################################################
# PARAMETERS

FITTER_REGISTRY = dict() # package : sampler
# _fitter_package_registry = dict() # sampler : package


##############################################################################
# CODE
##############################################################################


class PotentialFitter(PotentialBase):
"""Fit a Potential.
Parameters
----------
pot_type
The type of potential with which to fit the data.
Other Parameters
----------------
package : `~types.ModuleType` or str or None (optional, keyword only)
The package to which the `potential` belongs.
If not provided (None, default) tries to infer from `potential`.
return_specific_class : bool (optional, keyword only)
Whether to return a `PotentialSampler` or package-specific subclass.
This only applies if instantiating a `PotentialSampler`.
Default False.
"""

_registry = MappingProxyType(FITTER_REGISTRY)

def __init_subclass__(cls, package: T.Union[str, ModuleType]):
super().__init_subclass__(package=package)

FITTER_REGISTRY[cls._package] = cls

# /defs

def __new__(
cls,
pot_type: T.Any,
*,
package: T.Union[ModuleType, str, None] = None,
return_specific_class: bool = False,
):
self = super().__new__(cls)

if cls is PotentialFitter:
package = self._infer_package(pot_type, package)
instance = FITTER_REGISTRY[package](pot_type)

if return_specific_class:
return instance
else:
self._instance = instance

return self

# /def

# def __init__(self, pot_type, **kwargs):
# self._fitter = pot_type

#################################################################
# Fitting

def __call__(
self,
c: CoordinateType,
c_err: T.Optional[CoordinateType] = None,
**kwargs,
):
return self._instance(c, c_err=c_err, **kwargs)

# /def

def fit(
self,
c: CoordinateType,
c_err: T.Optional[CoordinateType] = None,
**kwargs,
):
# pass to __call__
return self(c, c_err=c_err, **kwargs)

# /def

# # TODO? wrong place for this
# def draw_realization(self, c, c_err=None, **kwargs):
# """Draw a realization given the errors.

# .. todo::

# rename this function

# better than equal Gaussian errors

# See Also
# --------
# :meth:`~discO.core.sampler.draw_realization`

# """

# # for i in range(nrlz):

# # # FIXME! this is shit
# # rep = c.represent_as(coord.CartesianRepresentation)
# # rep_err = c_err.re

# # new_c = c.realize_frame(new_rep)

# # yield self(c, c_err=c_err, **kwarg)

# # /def


# /class

# -------------------------------------------------------------------


##############################################################################
# END
20 changes: 20 additions & 0 deletions discO/extern/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# see LICENSE.rst

"""**DOCSTRING**."""


__all__ = []


##############################################################################
# IMPORTS

# PROJECT-SPECIFIC
from . import agama
from .agama import *

__all__ += agama.__all__

##############################################################################
# END
20 changes: 20 additions & 0 deletions discO/extern/agama/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
# see LICENSE.rst

"""AGAMA interface."""

__all__ = []


##############################################################################
# IMPORTS

# PROJECT-SPECIFIC
from . import fitter
from .fitter import * # noqa: F401, F403

__all__ += fitter.__all__


##############################################################################
# END
Loading

0 comments on commit 46593c4

Please sign in to comment.