Skip to content
This repository has been archived by the owner on Aug 6, 2024. It is now read-only.

Commit

Permalink
Fitters to Samples of the Potential (#20)
Browse files Browse the repository at this point in the history
* fit stuff
* fit on multidimensional

Signed-off-by: Nathaniel Starkman (@nstarman) <nstarkman@protonmail.com>
Signed-off-by: Christopher Carr (@CCAstro35 )
  • Loading branch information
nstarman authored Jan 20, 2021
1 parent 4f4a865 commit 33b20a7
Show file tree
Hide file tree
Showing 22 changed files with 1,402 additions and 192 deletions.
62 changes: 48 additions & 14 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Modules:
- ``core`` : the base class. [#17]
- ``sample`` : for sampling from a Potential. [#17]
- ``measurement`` : for resampling, given observational errors. [#17]

- ``fitter`` : for fitting a Potential given a sample [#20]

**discO.core.core**

Expand All @@ -55,30 +55,48 @@ subclasses must override the ``_registry`` and ``__call__`` methods.

**discO.core.sample**

PotentialSampler : base class for sampling potentials [#17]
``PotentialSampler`` : base class for sampling potentials [#17]

+ registers subclasses. Each subclass is for sampling from potentials from
a different package. Eg. ``GalpyPotentialSampler`` for sampling ``galpy``
potentials.
a different package. Eg. ``GalpyPotentialSampler`` for sampling
``galpy`` potentials.
+ PotentialSampler can be used to initialize & wrap any of its subclasses.
This is controlled by the argument ``return_specific_class``. If False,
it returns the subclass itself.
+ Takes a ``potential`` and a ``frame`` (astropy CoordinateFrame). The
potential is used for sampling, but the resultant points are not located
potential is used for sampling, but the resulting points are not located
in any reference frame, which we assign with ``frame``.
+ ``__call__`` and ``sample`` are used to sample the potential
+ ``resample`` (and ``resampler``) sample the potential many times. This can
be done for many iterations and different sample number points.
+ ``sample`` samples the potential many times. This
can be done for many iterations and different sample number points.
+ ``sample_iter`` samples the potential many times as a generator.


**discO.core.measurement**

- MeasurementErrorSampler : abstract base class for resampling a potential given measurement errors [#17]
- ``MeasurementErrorSampler`` : base class for resampling a potential given
measurement errors [#17]

+ registers subclasses. Each subclass is for resampling in a different
way.
+ ``MeasurementErrorSampler`` is a registry wrapper class and can be used
in-place of any of its subclasses.

- ``GaussianMeasurementErrorSampler`` : uncorrelated Gaussian errors [#17]

+ registers subclasses. Each subclass is for resampling in a different way.
+ MeasurementErrorSampler can be used to wrap any of its subclasses.

- GaussianMeasurementErrorSampler : apply uncorrelated Gaussian errors [#17]
**discO.core.fitter**

- ``PotentialFitter`` : base class for fitting potentials [#20]

+ registers subclasses.
+ PotentialFitter can be used to initialize & wrap any of its subclasses.
This is controlled by the argument ``return_specific_class``. If False,
it returns the subclass itself.
+ Takes a ``potential_cls`` and ``key`` argument which are used to figure
out the desired subclass, and how to fit the potential.
+ ``__call__`` and ``fit`` are used to fit the potential, with the latter
working on N-D samples (multiple iterations).


discO.data
Expand All @@ -87,22 +105,38 @@ discO.data
- Add Milky_Way_Sim_100 data [#10]


discO.extern
discO.plugin
^^^^^^^^^^^^

Where classes for external packages are held.


discO.extern.agama
discO.plugin.agama
^^^^^^^^^^^^^^^^^^

- AGAMAPotentialSampler [#17]

+ Sample from ``agama`` potentials.
+ Subclass of ``PotentialSampler``
+ stores the mass and potential as attributes on the returned ``SkyCoord``

- AGAMAPotentialFitter [#20]

+ Fit ``agama`` potentials.
+ Subclass of ``PotentialFitter``
+ registers subclasses for different fit methods.
+ AGAMAPotentialFitter can be used to initialize & wrap any of its
subclasses. This is controlled by the argument ``return_specific_class``. If False, it returns the subclass itself.
+ Takes a ``pot_type`` argument which is used to figure
out the desired subclass, and how to fit the potential.

- AGAMAMultipolePotentialFitter [#20]

+ Fit ``agama`` potentials with a multipole
+ Subclass of ``AGAMAPotentialFitter``


discO.extern.galpy
discO.plugin.galpy
^^^^^^^^^^^^^^^^^^

- GalpyPotentialSampler [#17]
Expand Down
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include .mailmap
include *.yaml
include *.yml

recursive-include discO *.pyx *.c *.pxd *.cfg
recursive-include discO *.py *.pyx *.c *.pxd *.cfg *.galpyrc
recursive-include docs *
recursive-include licenses *
recursive-include scripts *
Expand Down
4 changes: 3 additions & 1 deletion discO/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# flatten structure

# PROJECT-SPECIFIC
from . import sample
from . import fitter, sample
from .fitter import * # noqa: F401, F403
from .measurement import * # noqa: F403
from .sample import * # noqa: F403

# alls
__all__ += sample.__all__
__all__ += measurement.__all__
__all__ += fitter.__all__


##############################################################################
Expand Down
87 changes: 64 additions & 23 deletions discO/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
import inspect
import typing as T
from abc import ABCMeta, abstractmethod
from types import ModuleType
from collections import Sequence
from types import MappingProxyType, ModuleType

# THIRD PARTY
from astropy.utils.decorators import classproperty
from astropy.utils.introspection import resolve_name

##############################################################################
# PARAMETERS


##############################################################################
# CODE
##############################################################################
Expand All @@ -43,21 +41,29 @@ class PotentialBase(metaclass=ABCMeta):
"""

#################################################################
#######################################################
# On the class

def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None):
def __init_subclass__(
cls,
key: T.Union[
str,
ModuleType,
None,
T.Sequence[T.Union[ModuleType, str]],
] = None,
):
"""Initialize a subclass.
This method applies to all subclasses, not matter the
inheritance depth, unless the MRO overrides.
Parameters
----------
package : str or `~types.ModuleType` or None (optional)
key : str or `~types.ModuleType` or None (optional)
If the package is not None, resolves package module
and stores it in attribute ``_package``.
If the key is not None, resolves key module
and stores it in attribute ``_key``.
.. todo::
Expand All @@ -66,17 +72,32 @@ def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None):
"""
super().__init_subclass__()

if package is not None:
if key is not None:
key = cls._parse_registry_path(key)

if key in cls._registry:
raise KeyError(f"`{key}` sampler already in registry.")

if isinstance(package, str):
package = resolve_name(package)
elif not isinstance(package, ModuleType):
raise TypeError
cls._key = key

if package in cls._registry:
raise KeyError(f"`{package}` sampler already in registry.")
# /def

def __class_getitem__(cls, key):
if isinstance(key, str):
item = cls._registry[key]
elif len(key) == 1:
item = cls._registry[key[0]]
else:
item = cls._registry[key[0]][key[1:]]

cls._package = package
return item

# /def

@classproperty
def registry(self):
"""The class registry."""
return MappingProxyType(self._registry)

# /def

Expand All @@ -87,11 +108,6 @@ def _registry(self):

# /def

def __class_getitem__(cls, key):
return cls._registry[key]

# /def

#################################################################
# On the instance

Expand Down Expand Up @@ -140,6 +156,31 @@ def _infer_package(

# /def

@staticmethod
def _parse_registry_path(path):

if isinstance(path, str):
parsed = path
elif isinstance(path, ModuleType):
parsed = path.__name__
elif isinstance(path, Sequence):
parsed = []
for p in path:
if isinstance(p, str):
parsed.append(p)
elif isinstance(p, ModuleType):
parsed.append(p.__name__)
else:
raise TypeError(
f"{path} is not <str, ModuleType, Sequence>",
)
else:
raise TypeError(f"{path} is not <str, ModuleType, Sequence>")

return parsed

# /def


# /class

Expand Down
Loading

0 comments on commit 33b20a7

Please sign in to comment.