From 5dff71aca679d3345d549c1fd3c49db8a20535ae Mon Sep 17 00:00:00 2001 From: "Nathaniel Starkman (@nstarman)" Date: Tue, 19 Jan 2021 14:26:34 -0500 Subject: [PATCH] fit on multidimensional Signed-off-by: Nathaniel Starkman (@nstarman) --- CHANGES.rst | 62 +++- MANIFEST.in | 2 +- discO/core/__init__.py | 2 +- discO/core/core.py | 87 +++-- discO/core/fitter.py | 195 +++++++---- discO/core/measurement.py | 4 +- discO/core/sample.py | 54 ++-- discO/core/tests/__init__.py | 4 +- discO/core/tests/test_core.py | 57 +++- discO/core/tests/test_fitter.py | 408 ++++++++++++++++++++++++ discO/core/tests/test_measurement.py | 2 +- discO/core/tests/test_sample.py | 84 ++--- discO/extern/__init__.py | 20 -- discO/extern/agama/__init__.py | 20 -- discO/extern/agama/fitter.py | 68 ---- discO/plugin/agama/__init__.py | 4 +- discO/plugin/agama/fitter.py | 158 +++++++++ discO/plugin/agama/sample.py | 4 +- discO/plugin/agama/tests/__init__.py | 2 + discO/plugin/agama/tests/test_fitter.py | 180 +++++++++++ discO/plugin/agama/tests/test_sample.py | 6 +- discO/plugin/galpy/sample.py | 19 +- discO/plugin/galpy/tests/.galpyrc | 13 + discO/plugin/galpy/tests/test_sample.py | 12 +- docs/examples/sampling.ipynb | 194 ++++++++--- 25 files changed, 1304 insertions(+), 357 deletions(-) create mode 100644 discO/core/tests/test_fitter.py delete mode 100644 discO/extern/__init__.py delete mode 100644 discO/extern/agama/__init__.py delete mode 100644 discO/extern/agama/fitter.py create mode 100644 discO/plugin/agama/fitter.py create mode 100644 discO/plugin/agama/tests/test_fitter.py create mode 100644 discO/plugin/galpy/tests/.galpyrc diff --git a/CHANGES.rst b/CHANGES.rst index f7672e13..9fc479bb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -41,7 +41,7 @@ Modules: - ``core`` : the base class. [#17] - ``sample`` : for sampling from a Potential. [#17] - ``measurement`` : for resampling, given observational errors. [#17] - +- ``fitter`` : for fitting a Potential given a sample [#20] **discO.core.core** @@ -55,30 +55,48 @@ subclasses must override the ``_registry`` and ``__call__`` methods. **discO.core.sample** -PotentialSampler : base class for sampling potentials [#17] +``PotentialSampler`` : base class for sampling potentials [#17] + registers subclasses. Each subclass is for sampling from potentials from - a different package. Eg. ``GalpyPotentialSampler`` for sampling ``galpy`` - potentials. + a different package. Eg. ``GalpyPotentialSampler`` for sampling + ``galpy`` potentials. + PotentialSampler can be used to initialize & wrap any of its subclasses. This is controlled by the argument ``return_specific_class``. If False, it returns the subclass itself. + Takes a ``potential`` and a ``frame`` (astropy CoordinateFrame). The - potential is used for sampling, but the resultant points are not located + potential is used for sampling, but the resulting points are not located in any reference frame, which we assign with ``frame``. + ``__call__`` and ``sample`` are used to sample the potential - + ``resample`` (and ``resampler``) sample the potential many times. This can - be done for many iterations and different sample number points. + + ``sample`` samples the potential many times. This + can be done for many iterations and different sample number points. + + ``sample_iter`` samples the potential many times as a generator. **discO.core.measurement** -- MeasurementErrorSampler : abstract base class for resampling a potential given measurement errors [#17] +- ``MeasurementErrorSampler`` : base class for resampling a potential given + measurement errors [#17] + + + registers subclasses. Each subclass is for resampling in a different + way. + + ``MeasurementErrorSampler`` is a registry wrapper class and can be used + in-place of any of its subclasses. + +- ``GaussianMeasurementErrorSampler`` : uncorrelated Gaussian errors [#17] - + registers subclasses. Each subclass is for resampling in a different way. - + MeasurementErrorSampler can be used to wrap any of its subclasses. -- GaussianMeasurementErrorSampler : apply uncorrelated Gaussian errors [#17] +**discO.core.fitter** + +- ``PotentialFitter`` : base class for fitting potentials [#20] + + + registers subclasses. + + PotentialFitter can be used to initialize & wrap any of its subclasses. + This is controlled by the argument ``return_specific_class``. If False, + it returns the subclass itself. + + Takes a ``potential_cls`` and ``key`` argument which are used to figure + out the desired subclass, and how to fit the potential. + + ``__call__`` and ``fit`` are used to fit the potential, with the latter + working on N-D samples (multiple iterations). discO.data @@ -87,22 +105,38 @@ discO.data - Add Milky_Way_Sim_100 data [#10] -discO.extern +discO.plugin ^^^^^^^^^^^^ Where classes for external packages are held. -discO.extern.agama +discO.plugin.agama ^^^^^^^^^^^^^^^^^^ - AGAMAPotentialSampler [#17] + Sample from ``agama`` potentials. + + Subclass of ``PotentialSampler`` + stores the mass and potential as attributes on the returned ``SkyCoord`` +- AGAMAPotentialFitter [#20] + + + Fit ``agama`` potentials. + + Subclass of ``PotentialFitter`` + + registers subclasses for different fit methods. + + AGAMAPotentialFitter can be used to initialize & wrap any of its + subclasses. This is controlled by the argument ``return_specific_class``. If False, it returns the subclass itself. + + Takes a ``pot_type`` argument which is used to figure + out the desired subclass, and how to fit the potential. + +- AGAMAMultipolePotentialFitter [#20] + + + Fit ``agama`` potentials with a multipole + + Subclass of ``AGAMAPotentialFitter`` + -discO.extern.galpy +discO.plugin.galpy ^^^^^^^^^^^^^^^^^^ - GalpyPotentialSampler [#17] diff --git a/MANIFEST.in b/MANIFEST.in index 8225fb6a..4f6184b1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,7 +8,7 @@ include .mailmap include *.yaml include *.yml -recursive-include discO *.pyx *.c *.pxd *.cfg +recursive-include discO *.py *.pyx *.c *.pxd *.cfg *.galpyrc recursive-include docs * recursive-include licenses * recursive-include scripts * diff --git a/discO/core/__init__.py b/discO/core/__init__.py index 1e81e636..96d44a4f 100644 --- a/discO/core/__init__.py +++ b/discO/core/__init__.py @@ -11,7 +11,7 @@ # flatten structure # PROJECT-SPECIFIC -from . import sample, fitter +from . import fitter, sample from .fitter import * # noqa: F401, F403 from .measurement import * # noqa: F403 from .sample import * # noqa: F403 diff --git a/discO/core/core.py b/discO/core/core.py index 142497e3..c68fcd29 100644 --- a/discO/core/core.py +++ b/discO/core/core.py @@ -14,15 +14,13 @@ import inspect import typing as T from abc import ABCMeta, abstractmethod -from types import ModuleType +from collections import Sequence +from types import MappingProxyType, ModuleType # THIRD PARTY +from astropy.utils.decorators import classproperty from astropy.utils.introspection import resolve_name -############################################################################## -# PARAMETERS - - ############################################################################## # CODE ############################################################################## @@ -43,10 +41,18 @@ class PotentialBase(metaclass=ABCMeta): """ - ################################################################# + ####################################################### # On the class - def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None): + def __init_subclass__( + cls, + key: T.Union[ + str, + ModuleType, + None, + T.Sequence[T.Union[ModuleType, str]], + ] = None, + ): """Initialize a subclass. This method applies to all subclasses, not matter the @@ -54,10 +60,10 @@ def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None): Parameters ---------- - package : str or `~types.ModuleType` or None (optional) + key : str or `~types.ModuleType` or None (optional) - If the package is not None, resolves package module - and stores it in attribute ``_package``. + If the key is not None, resolves key module + and stores it in attribute ``_key``. .. todo:: @@ -66,17 +72,32 @@ def __init_subclass__(cls, package: T.Union[str, ModuleType, None] = None): """ super().__init_subclass__() - if package is not None: + if key is not None: + key = cls._parse_registry_path(key) + + if key in cls._registry: + raise KeyError(f"`{key}` sampler already in registry.") - if isinstance(package, str): - package = resolve_name(package) - elif not isinstance(package, ModuleType): - raise TypeError + cls._key = key - if package in cls._registry: - raise KeyError(f"`{package}` sampler already in registry.") + # /def + + def __class_getitem__(cls, key): + if isinstance(key, str): + item = cls._registry[key] + elif len(key) == 1: + item = cls._registry[key[0]] + else: + item = cls._registry[key[0]][key[1:]] - cls._package = package + return item + + # /def + + @classproperty + def registry(self): + """The class registry.""" + return MappingProxyType(self._registry) # /def @@ -87,11 +108,6 @@ def _registry(self): # /def - def __class_getitem__(cls, key): - return cls._registry[key] - - # /def - ################################################################# # On the instance @@ -140,6 +156,31 @@ def _infer_package( # /def + @staticmethod + def _parse_registry_path(path): + + if isinstance(path, str): + parsed = path + elif isinstance(path, ModuleType): + parsed = path.__name__ + elif isinstance(path, Sequence): + parsed = [] + for p in path: + if isinstance(p, str): + parsed.append(p) + elif isinstance(p, ModuleType): + parsed.append(p.__name__) + else: + raise TypeError( + f"{path} is not ", + ) + else: + raise TypeError(f"{path} is not ") + + return parsed + + # /def + # /class diff --git a/discO/core/fitter.py b/discO/core/fitter.py index d9c439f8..f86ff482 100644 --- a/discO/core/fitter.py +++ b/discO/core/fitter.py @@ -18,8 +18,13 @@ # IMPORTS # BUILT-IN +import inspect import typing as T -from types import MappingProxyType, ModuleType +import warnings +from types import ModuleType + +# THIRD PARTY +import numpy as np # PROJECT-SPECIFIC from .core import PotentialBase @@ -28,9 +33,7 @@ ############################################################################## # PARAMETERS -FITTER_REGISTRY = dict() # package : sampler -# _fitter_package_registry = dict() # sampler : package - +FITTER_REGISTRY = dict() # package : samplers ############################################################################## # CODE @@ -40,109 +43,187 @@ class PotentialFitter(PotentialBase): """Fit a Potential. + .. todo:: + + This is registering by the package, which I think may be the wrong + approach most packages have multiple fitters, which should be easily + accessible. + Parameters ---------- - pot_type + potential_cls The type of potential with which to fit the data. Other Parameters ---------------- - package : `~types.ModuleType` or str or None (optional, keyword only) - The package to which the `potential` belongs. + key : `~types.ModuleType` or str or None (optional, keyword only) + The key to which the `potential` belongs. If not provided (None, default) tries to infer from `potential`. return_specific_class : bool (optional, keyword only) - Whether to return a `PotentialSampler` or package-specific subclass. + Whether to return a `PotentialSampler` or key-specific subclass. This only applies if instantiating a `PotentialSampler`. Default False. """ - _registry = MappingProxyType(FITTER_REGISTRY) + ################################################################# + # On the class + + _registry = FITTER_REGISTRY + + def __init_subclass__(cls, key: T.Union[str, ModuleType] = None): + """Initialize subclass, adding to registry by `key`. + + This method applies to all subclasses, not matter the + inheritance depth, unless the MRO overrides. - def __init_subclass__(cls, package: T.Union[str, ModuleType]): - super().__init_subclass__(package=package) + """ + super().__init_subclass__(key=key) - FITTER_REGISTRY[cls._package] = cls + if key is not None: # same trigger as PotentialBase + # get the registry on this (the parent) object + # cls._key defined in super() + cls.__bases__[0]._registry[cls._key] = cls + + # TODO? insist that subclasses define a __call__ method + # this "abstractifies" the base-class even though it can be used + # as a wrapper class. # /defs + ################################################################# + # On the instance + def __new__( cls, - pot_type: T.Any, + potential_cls: T.Any, *, - package: T.Union[ModuleType, str, None] = None, + key: T.Union[ModuleType, str, None] = None, return_specific_class: bool = False, + **kwargs, ): self = super().__new__(cls) + self._fitter = potential_cls + # The class PotentialFitter is a wrapper for anything in its registry + # If directly instantiating a PotentialFitter (not subclass) we must + # also instantiate the appropriate subclass. Error if can't find. if cls is PotentialFitter: - package = self._infer_package(pot_type, package) - instance = FITTER_REGISTRY[package](pot_type) + # infer the key, to add to registry + key = self._infer_package(potential_cls, key).__name__ + + if key not in cls._registry: + raise ValueError( + "PotentialFitter has no registered fitter for key: " + f"{key}", + ) + + # from registry. Registered in __init_subclass__ + # some subclasses accept the potential_cls as an argument, + # others do not. + subcls = cls._registry[key] + sig = inspect.signature(subcls) + ba = sig.bind_partial(potential_cls=potential_cls, **kwargs) + ba.apply_defaults() + + instance = cls._registry[key](*ba.args, **ba.kwargs) if return_specific_class: return instance - else: - self._instance = instance + + self._instance = instance + + elif key is not None: + raise ValueError( + "Can't specify 'key' on PotentialFitter subclasses.", + ) + + elif return_specific_class is not False: + warnings.warn("Ignoring argument `return_specific_class`") return self # /def - # def __init__(self, pot_type, **kwargs): - # self._fitter = pot_type - - ################################################################# + ####################################################### # Fitting def __call__( self, - c: CoordinateType, - c_err: T.Optional[CoordinateType] = None, + sample: CoordinateType, + # sample_err: T.Optional[CoordinateType] = None, **kwargs, ): - return self._instance(c, c_err=c_err, **kwargs) + """Fit. + + Parameters + ---------- + sample : `SkyCoord` + **kwargs + passed to underlying instance + + Returns + ------- + Potential : object + + """ + # call on instance + return self._instance( + sample, + # c_err=c_err, + **kwargs, + ) # /def def fit( self, - c: CoordinateType, - c_err: T.Optional[CoordinateType] = None, + sample: CoordinateType, **kwargs, ): - # pass to __call__ - return self(c, c_err=c_err, **kwargs) + """Fit. + + .. todo:: + + Subclass SkyCoord and have metadata mass and potential that + carry-over. Or embed a SkyCoord in a table with the other + attributes. or something so that doesn't need continual + reassignment + + Parameters + ---------- + sample : `SkyCoord` + can have shape (nsamp, ) or (nsamp, niter) + # sample_err: T.Optional[CoordinateType] = None, + **kwargs + passed to underlying instance + + Returns + ------- + Potential : object + + """ + if len(sample.shape) == 1: # (nsamp, ) -> (nsamp, niter=1) + mass, potential = sample.mass, sample.potential + sample = sample.reshape((-1, 1)) + sample.mass, sample.potential = mass.reshape((-1, 1)), potential + + # shape (niter, ) + niter = sample.shape[1] + fits = np.empty(niter, dtype=sample.potential.__class__) + + # (niter, nsamp) -> iter on niter + for i, (samp, mass) in enumerate(zip(sample.T, sample.mass.T)): + samp.mass, samp.potential = mass, sample.potential + fits[i] = self(samp, **kwargs) + + if niter == 1: + return fits[0] + else: + return fits # /def - # # TODO? wrong place for this - # def draw_realization(self, c, c_err=None, **kwargs): - # """Draw a realization given the errors. - - # .. todo:: - - # rename this function - - # better than equal Gaussian errors - - # See Also - # -------- - # :meth:`~discO.core.sampler.draw_realization` - - # """ - - # # for i in range(nrlz): - - # # # FIXME! this is shit - # # rep = c.represent_as(coord.CartesianRepresentation) - # # rep_err = c_err.re - - # # new_c = c.realize_frame(new_rep) - - # # yield self(c, c_err=c_err, **kwarg) - - # # /def - # /class diff --git a/discO/core/measurement.py b/discO/core/measurement.py index e8016082..69b24d32 100644 --- a/discO/core/measurement.py +++ b/discO/core/measurement.py @@ -39,7 +39,7 @@ ############################################################################## # PARAMETERS -MEASURE_REGISTRY = dict() # package : measurer +MEASURE_REGISTRY = dict() # key : measurer ############################################################################## # CODE @@ -98,7 +98,7 @@ def __init_subclass__(cls): inheritance depth, unless the MRO overrides. """ - super().__init_subclass__(package=None) + super().__init_subclass__(key=None) key = cls.__name__ if key in cls._registry: diff --git a/discO/core/sample.py b/discO/core/sample.py index 3205fc36..791f8686 100644 --- a/discO/core/sample.py +++ b/discO/core/sample.py @@ -41,7 +41,7 @@ class GalpyPotentialSampler(PotentialSampler): import typing as T import warnings from contextlib import nullcontext -from types import MappingProxyType, ModuleType +from types import ModuleType # THIRD PARTY import numpy as np @@ -56,7 +56,7 @@ class GalpyPotentialSampler(PotentialSampler): ############################################################################## # PARAMETERS -SAMPLER_REGISTRY = dict() # package : sampler +SAMPLER_REGISTRY = dict() # key : sampler Random_Like = T.Union[int, np.random.Generator, np.random.RandomState, None] @@ -83,8 +83,8 @@ class PotentialSampler(PotentialBase): Other Parameters ---------------- - package : `~types.ModuleType` or str or None (optional, keyword only) - The package to which the `potential` belongs. + key : `~types.ModuleType` or str or None (optional, keyword only) + The key to which the `potential` belongs. If not provided (None, default) tries to infer from `potential`. return_specific_class : bool (optional, keyword only) Whether to return a `PotentialSampler` or package-specific subclass. @@ -95,27 +95,27 @@ class PotentialSampler(PotentialBase): ------ ValueError If directly instantiating a PotentialSampler (not subclass) and cannot - find the appropriate subclass, identified using ``package``. + find the appropriate subclass, identified using ``key``. """ ################################################################# # On the class - _registry = MappingProxyType(SAMPLER_REGISTRY) + _registry = SAMPLER_REGISTRY - def __init_subclass__(cls, package: T.Union[str, ModuleType] = None): - """Initialize subclass, adding to registry by `package`. + def __init_subclass__(cls, key: T.Union[str, ModuleType] = None): + """Initialize subclass, adding to registry by `key`. This method applies to all subclasses, not matter the inheritance depth, unless the MRO overrides. """ - super().__init_subclass__(package=package) + super().__init_subclass__(key=key) - if package is not None: # same trigger as PotentialBase - # cls._package defined in super() - SAMPLER_REGISTRY[cls._package] = cls + if key is not None: # same trigger as PotentialBase + # cls._key defined in super() + cls.__bases__[0]._registry[cls._key] = cls # TODO? insist that subclasses define a __call__ method # this "abstractifies" the base-class even though it can be used @@ -131,7 +131,7 @@ def __new__( potential: T.Any, *, frame: T.Optional[FrameLikeType] = None, - package: T.Union[ModuleType, str, None] = None, + key: T.Union[ModuleType, str, None] = None, return_specific_class: bool = False, ): self = super().__new__(cls) @@ -140,17 +140,17 @@ def __new__( # If directly instantiating a PotentialSampler (not subclass) we must # also instantiate the appropriate subclass. Error if can't find. if cls is PotentialSampler: - # infer the package, to add to registry - package = self._infer_package(potential, package) + # infer the key, to add to registry + key = self._infer_package(potential, key).__name__ - if package not in cls._registry: + if key not in cls._registry: raise ValueError( - "PotentialSampler has no registered sampler for package: " - f"{package}", + "PotentialSampler has no registered sampler for key: " + f"{key}", ) # from registry. Registered in __init_subclass__ - instance = cls[package](potential) + instance = cls[key](potential) # Whether to return class or subclass # else continue, storing instance @@ -159,9 +159,9 @@ def __new__( self._instance = instance - elif package is not None: + elif key is not None: raise ValueError( - "Can't specify 'package' on PotentialSampler subclasses.", + "Can't specify 'key' on PotentialSampler subclasses.", ) elif return_specific_class is not False: @@ -280,6 +280,13 @@ def sample( ): """Sample the potential. + .. todo:: + + Subclass SkyCoord and have metadata mass and potential that + carry-over. Or embed a SkyCoord in a table with the other + attributes. or something so that doesn't need continual + reassignment + Parameters ---------- n : int or sequence @@ -323,15 +330,20 @@ def sample( for i, N in enumerate(itersamp): samps = [None] * niter # premake array + mass = [None] * niter # premake array for j in range(0, niter): samp = self(n=N, frame=frame, random=random, **kwargs) samps[j] = samp + mass[j] = samp.mass if j == 0: # 0-dimensional doesn't need concat sample = samps[0] else: sample = concatenate(samps).reshape((N, niter)) + sample.mass = np.vstack(mass).T + sample.potential = samp.potential # all the same + samples[i] = sample if np.isscalar(n): diff --git a/discO/core/tests/__init__.py b/discO/core/tests/__init__.py index e17acfd8..025ea21f 100644 --- a/discO/core/tests/__init__.py +++ b/discO/core/tests/__init__.py @@ -6,8 +6,9 @@ __all__ = [ # modules "core_tests", - "measurement_tests", "sample_tests", + "measurement_tests", + "fitter_tests", # instance "test", ] @@ -24,6 +25,7 @@ # PROJECT-SPECIFIC from . import test_core as core_tests +from . import test_fitter as fitter_tests from . import test_measurement as measurement_tests from . import test_sample as sample_tests diff --git a/discO/core/tests/test_core.py b/discO/core/tests/test_core.py index 0322aaf2..472faa21 100644 --- a/discO/core/tests/test_core.py +++ b/discO/core/tests/test_core.py @@ -13,6 +13,7 @@ # BUILT-IN from abc import abstractmethod from collections.abc import Mapping +from types import MappingProxyType # THIRD PARTY import pytest @@ -36,26 +37,26 @@ class Test_PotentialBase(ObjectTest, obj=core.PotentialBase): def test___init_subclass__(self): """Test subclassing.""" # -------------------- - # When package is None + # When key is None class SubClasss1(self.obj): _registry = {} - assert not hasattr(SubClasss1, "_package") + assert not hasattr(SubClasss1, "_key") # -------------------- - # When package is str + # When key is str - class SubClasss2(self.obj, package="pytest"): + class SubClasss2(self.obj, key="pytest"): _registry = {} - assert SubClasss2._package == pytest + assert SubClasss2._key == "pytest" # -------------------- # test error with pytest.raises(TypeError): - class SubClasss3(self.obj, package=Exception): + class SubClasss3(self.obj, key=Exception): _registry = {} # /def @@ -69,6 +70,12 @@ def test__registry(self): # /def + def test_registry(self): + # This doesn't run on `Test_PotentialBase`, but should + # run on all registry subclasses. + if isinstance(self.obj._registry, Mapping): + assert isinstance(self.obj.registry, MappingProxyType) + # ------------------------------- @abstractmethod @@ -78,13 +85,23 @@ def test___class_getitem__(self): # or a Mapping, for normal classes. assert isinstance(self.obj._registry, (property, Mapping)) + # --------- + # This doesn't run on `Test_PotentialBase`, but should # run on all registry subclasses. if isinstance(self.obj._registry, Mapping): # a very basic equality test for k in self.obj._registry: + # str assert self.obj[k] is self.obj._registry[k] + # iterable of len = 1 + assert self.obj[[k]] is self.obj._registry[k] + + # multi-length iterable that fails + with pytest.raises(KeyError): + self.obj[[k, KeyError]] + # /def # ------------------------------- @@ -116,7 +133,7 @@ def test___call__(self): def test__infer_package(self): """Test method ``_infer_package``.""" - # when package is None + # when key is None assert self.obj._infer_package(self.obj) == discO # when pass package @@ -138,8 +155,32 @@ def test__infer_package(self): # /def + # ------------------------------- + + def test__parse_registry_path(self): + """Test method ``_parse_registry_path``.""" + # str -> str + assert self.obj._parse_registry_path("pytest") == "pytest" + + # module -> str + assert self.obj._parse_registry_path(pytest) == "pytest" + + # Sequence + assert self.obj._parse_registry_path(("pytest", discO)) == [ + "pytest", + "discO", + ] + + # failure in Sequence + with pytest.raises(TypeError): + self.obj._parse_registry_path((None,)) + + # failure in normal call + with pytest.raises(TypeError): + self.obj._parse_registry_path(None) + ################################################################# - # Pipeline Tests + # Usage Tests # N/A b/c abstract base-class diff --git a/discO/core/tests/test_fitter.py b/discO/core/tests/test_fitter.py new file mode 100644 index 00000000..bbb3ec3a --- /dev/null +++ b/discO/core/tests/test_fitter.py @@ -0,0 +1,408 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`~discO.core.fitter`.""" + +__all__ = [ + "Test_PotentialFitter", +] + + +############################################################################## +# IMPORTS + +# BUILT-IN +import inspect +from abc import abstractmethod + +# THIRD PARTY +import astropy.coordinates as coord +import astropy.units as u +import numpy as np +import pytest + +# PROJECT-SPECIFIC +from discO.core import fitter +from discO.core.tests.test_core import Test_PotentialBase as PotentialBase_Test + +############################################################################## +# PARAMETERS + +crd = coord.SkyCoord( + coord.ICRS( + ra=[ + 269.77, + 211.53, + 135.49, + 3.85, + 42.11, + 212.56, + 203.11, + 61.49, + 344.11, + 98.63, + ] + * u.deg, + dec=[ + -80.39242629, + -3.67881258, + -44.62636438, + -7.46999137, + -20.90390085, + -64.15957604, + -9.16456976, + -33.66474899, + -41.05292432, + 16.56923216, + ] + * u.deg, + distance=[ + 12.15818053, + 7.37721302, + 156.25820005, + 5.08874191, + 7.7856392, + 16.58761413, + 6.31741618, + 3.83061213, + 4.97326983, + 32.21408322, + ] + * u.kpc, + ), +) +crd.mass = np.ones(10) * u.solMass + +multicrd = crd.reshape((5, 2)) +multicrd.mass = crd.mass.reshape((5, 2)) + + +class FitterSubClass(fitter.PotentialFitter, key="test_discO"): + def __call__(self, c, **kwargs): + c.represent_as(coord.CartesianRepresentation) + return object() + + # /def + + +# /class + + +############################################################################## +# PYTEST + + +def teardown_module(module): + """Teardown any state that was previously set up.""" + FitterSubClass._registry.pop("test_discO", None) + + +# /def + +############################################################################## +# TESTS +############################################################################## + + +class Test_PotentialFitter(PotentialBase_Test, obj=fitter.PotentialFitter): + @classmethod + def setup_class(cls): + """Setup fixtures for testing.""" + cls.potential = object + + # register a unittest examples + class SubClassUnitTest(cls.obj, key="unittest"): + def __call__(self, c, **kwargs): + c.represent_as(coord.CartesianRepresentation) + return cls.potential() + + cls.SubClassUnitTest = SubClassUnitTest + + # make instance. It depends. + if cls.obj is fitter.PotentialFitter: + cls.inst = cls.obj(cls.potential, key="unittest") + + # /def + + @classmethod + def teardown_class(cls): + """Teardown fixtures for testing.""" + cls.SubClassUnitTest._registry.pop("unittest", None) + + # /def + + ################################################################# + # Method Tests + + def test___init_subclass__(self): + """Test subclassing.""" + # When package is None, it is not registered + class SubClass1(self.obj): + pass + + assert None not in fitter.FITTER_REGISTRY + assert SubClass1 not in fitter.FITTER_REGISTRY.values() + + # ------------------------ + # register a new + try: + + class SubClass1(self.obj, key="pytest"): + pass + + except Exception: + pass + finally: + fitter.FITTER_REGISTRY.pop("pytest", None) + + # ------------------------------- + # error when already in registry + + try: + # registered + class SubClass1(self.obj, key="pytest"): + pass + + # doing it again raises error + with pytest.raises(KeyError): + + class SubClass1(self.obj, key="pytest"): + pass + + except Exception: + pass + finally: # cleanup + fitter.FITTER_REGISTRY.pop("pytest", None) + + # /def + + # ------------------------------- + + def test__registry(self): + """Test method ``_registry``. + + As ``_registry`` is never overwritten in the subclasses, this test + should carry though. + + """ + # run tests on super + super().test__registry() + + # ------------------------------- + assert isinstance(self.obj._registry, dict) + + # The unittest is already registered, so can + # test for that. + assert "unittest" in self.obj._registry.keys() + assert self.SubClassUnitTest in self.obj._registry.values() + assert self.obj._registry["unittest"] is self.SubClassUnitTest + + # /def + + # ------------------------------- + + def test___class_getitem__(self): + """Test method ``__class_getitem__``.""" + # run tests on super + super().test___class_getitem__() + + # ------------------------------- + # test a specific item in the registry + assert self.obj["unittest"] is self.SubClassUnitTest + + # /def + + # ------------------------------- + + def test___new__(self): + """Test method ``__new__``. + + This is a wrapper class that acts differently when instantiating + a MeasurementErrorSampler than one of it's subclasses. + + """ + # there are no tests on super + # super().test___new__() + + # -------------------------- + if self.obj is fitter.PotentialFitter: + + # --------------- + # Need the "potential" argument + with pytest.raises(TypeError) as e: + self.obj() + + assert ( + "missing 1 required positional argument: 'potential_cls'" + ) in str(e.value) + + # -------------------------- + # for object not in registry + + with pytest.raises(ValueError) as e: + self.obj(self.potential()) + + assert ( + "PotentialFitter has no registered fitter for key: builtins" + ) in str(e.value) + + # --------------- + # with return_specific_class + + klass = self.obj._registry["unittest"] + + msamp = self.obj( + self.potential, + key="unittest", + return_specific_class=True, + ) + + # test class type + assert isinstance(msamp, klass) + assert isinstance(msamp, self.obj) + + # test inputs + assert msamp._fitter == self.potential + + # --------------- + # as wrapper class + + klass = self.obj._registry["unittest"] + + msamp = self.obj( + self.potential, + key="unittest", + return_specific_class=False, + ) + + # test class type + assert not isinstance(msamp, klass) + assert isinstance(msamp, self.obj) + assert isinstance(msamp._instance, klass) + + # test inputs + assert msamp._fitter == self.potential + + # -------------------------- + else: # never hit in Test_PotentialSampler, only in subs + + # --------------- + # Can't have the "key" argument + + with pytest.raises(ValueError) as e: + sig = inspect.signature(self.obj) + ba = sig.bind_partial( + potential_cls=self.potential, + key="not None", + ) + ba.apply_defaults() + self.obj(*ba.args, **ba.kwargs) + + assert "Can't specify 'key'" in str(e.value) + + # --------------- + # warning + + with pytest.warns(UserWarning): + self.obj( + self.potential, + key=None, + return_specific_class=True, + ) + + # --------------- + # AOK + + msamp = self.obj(self.potential, frame="icrs") + + assert self.obj is not fitter.PotentialFitter + assert isinstance(msamp, self.obj) + assert isinstance(msamp, fitter.PotentialFitter) + assert not hasattr(msamp, "_instance") + assert msamp._fitter == self.potential + + # /def + + # ------------------------------- + + @abstractmethod + def test___init__(self): + """Test method ``__init__``.""" + # run tests on super + super().test___init__() + + # -------------------------- + pass # for subclasses. The setup_class actually tests this for here. + + # /def + + # ------------------------------- + + def test___call__(self): + """Test method ``__call__``. + + When Test_MeasurementErrorSampler this calls on the wrapped instance, + which is GaussianMeasurementErrorSampler. + + We can't test the output, but can test that it "works". + + """ + # run tests on super + super().test___call__() + + # /def + + # TODO! with hypothesis + @pytest.mark.parametrize("sample", [crd]) + def test_call_parametrize(self, sample): + """Parametrized call tests.""" + res = self.inst(sample) + assert isinstance(res, self.potential) + + # /def + + # ------------------------------- + + # TODO! with hypothesis + @pytest.mark.parametrize("sample", [crd, multicrd]) + def test_fit(self, sample): + """Test method ``fit``.""" + # for test need to assign correct potential type + sample.potential = self.potential + + pots = self.inst.fit(sample) + + if len(sample.shape) == 1: + assert isinstance(pots, sample.potential) + + else: + assert isinstance(pots, np.ndarray) + assert len(pots) == sample.shape[1] + assert all([isinstance(p, sample.potential) for p in pots]) + + # and then cleanup + del sample.potential + + # /def + + +############################################################################## + + +# ------------------------------------------------------------------- + + +class Test_PotentialFitter_SubClass( + Test_PotentialFitter, + obj=FitterSubClass, +): + @classmethod + def setup_class(cls): + """Setup fixtures for testing.""" + super().setup_class() + cls.inst = cls.obj(cls.potential) + + # /def + + +############################################################################## +# END diff --git a/discO/core/tests/test_measurement.py b/discO/core/tests/test_measurement.py index 5332a94e..91499be0 100644 --- a/discO/core/tests/test_measurement.py +++ b/discO/core/tests/test_measurement.py @@ -62,7 +62,7 @@ def test___init_subclass__(self): class SubClass1(self.obj): pass - assert not hasattr(SubClass1, "_package") + assert not hasattr(SubClass1, "_key") assert "SubClass1" in measurement.MEASURE_REGISTRY except Exception: pass diff --git a/discO/core/tests/test_sample.py b/discO/core/tests/test_sample.py index f2bdc6ed..71a9bbee 100644 --- a/discO/core/tests/test_sample.py +++ b/discO/core/tests/test_sample.py @@ -12,9 +12,7 @@ # BUILT-IN import itertools -import unittest from abc import abstractmethod -from types import MappingProxyType # THIRD PARTY import astropy.coordinates as coord @@ -24,21 +22,21 @@ # PROJECT-SPECIFIC from discO.core import sample -from discO.core.tests.test_core import Test_PotentialBase +from discO.core.tests.test_core import Test_PotentialBase as PotentialBase_Test ############################################################################## # TESTS ############################################################################## -class Test_PotentialSampler(Test_PotentialBase, obj=sample.PotentialSampler): +class Test_PotentialSampler(PotentialBase_Test, obj=sample.PotentialSampler): @classmethod def setup_class(cls): """Setup fixtures for testing.""" cls.potential = object() # register a unittest examples - class SubClassUnitTest(cls.obj, package="unittest"): + class SubClassUnitTest(cls.obj, key="unittest"): def __call__(self, n, *, frame=None, random=None, **kwargs): # Get preferred frames frame = self._preferred_frame_resolve(frame) @@ -49,25 +47,29 @@ def __call__(self, n, *, frame=None, random=None, **kwargs): random = np.random.default_rng(random) # return - return coord.SkyCoord( + sample = coord.SkyCoord( coord.ICRS( ra=random.uniform(size=n) * u.deg, dec=2 * random.uniform(size=n) * u.deg, ), ).transform_to(frame) + sample.mass = np.ones(n) + sample.potential = cls.potential + + return sample cls.SubClassUnitTest = SubClassUnitTest - # make instance. It depends + # make instance. It depends. if cls.obj is sample.PotentialSampler: - cls.inst = cls.obj(cls.potential, package="unittest") + cls.inst = cls.obj(cls.potential, key="unittest") # /def @classmethod def teardown_class(cls): """Teardown fixtures for testing.""" - sample.SAMPLER_REGISTRY.pop(unittest, None) + cls.SubClassUnitTest._registry.pop("unittest", None) # /def @@ -87,32 +89,32 @@ class SubClass1(self.obj): # register a new try: - class SubClass1(self.obj, package="pytest"): + class SubClass1(self.obj, key="pytest"): pass except Exception: pass finally: - sample.SAMPLER_REGISTRY.pop(pytest, None) + sample.SAMPLER_REGISTRY.pop("pytest", None) # ------------------------------- # error when already in registry try: # registered - class SubClass1(self.obj, package="pytest"): + class SubClass1(self.obj, key="pytest"): pass # doing it again raises error with pytest.raises(KeyError): - class SubClass1(self.obj, package="pytest"): + class SubClass1(self.obj, key="pytest"): pass except Exception: pass finally: # cleanup - sample.SAMPLER_REGISTRY.pop(pytest, None) + sample.SAMPLER_REGISTRY.pop("pytest", None) # /def @@ -129,13 +131,13 @@ def test__registry(self): super().test__registry() # ------------------------------- - assert isinstance(self.obj._registry, MappingProxyType) + assert isinstance(self.obj._registry, dict) # The unittest is already registered, so can # test for that. - assert unittest in self.obj._registry.keys() + assert "unittest" in self.obj._registry.keys() assert self.SubClassUnitTest in self.obj._registry.values() - assert self.obj._registry[unittest] is self.SubClassUnitTest + assert self.obj._registry["unittest"] is self.SubClassUnitTest # /def @@ -148,7 +150,7 @@ def test___class_getitem__(self): # ------------------------------- # test a specific item in the registry - assert self.obj[unittest] is self.SubClassUnitTest + assert self.obj["unittest"] is self.SubClassUnitTest # /def @@ -183,18 +185,17 @@ def test___new__(self): self.obj(self.potential) assert ( - "PotentialSampler has no registered sampler for package: " - "" + "PotentialSampler has no registered sampler for key: builtins" ) in str(e.value) # --------------- # with return_specific_class - package, klass = tuple(self.obj._registry.items())[0] + key, klass = tuple(self.obj._registry.items())[0] msamp = self.obj( self.potential, - package=package, + key=key, return_specific_class=True, ) @@ -208,11 +209,11 @@ def test___new__(self): # --------------- # as wrapper class - package, klass = tuple(self.obj._registry.items())[0] + key, klass = tuple(self.obj._registry.items())[0] msamp = self.obj( self.potential, - package=package, + key=key, return_specific_class=False, ) @@ -228,12 +229,12 @@ def test___new__(self): else: # never hit in Test_PotentialSampler, only in subs # --------------- - # Can't have the "package" argument + # Can't have the "key" argument with pytest.raises(ValueError) as e: - self.obj(self.potential, package="not None") + self.obj(self.potential, key="not None") - assert "Can't specify 'package'" in str(e.value) + assert "Can't specify 'key'" in str(e.value) # --------------- # warning @@ -241,7 +242,7 @@ def test___new__(self): with pytest.warns(UserWarning): self.obj( self.potential, - package=None, + key=None, return_specific_class=True, ) @@ -298,7 +299,7 @@ def test___call__(self): def test_call_parametrize(self, n, frame, kwargs): """Parametrized call tests.""" res = self.inst(n, frame=frame, **kwargs) - assert res.__class__ == coord.SkyCoord + assert isinstance(res, coord.SkyCoord) # /def @@ -415,34 +416,15 @@ def test__preferred_frame_resolve(self): # /def + ################################################################# + # Usage Tests + # /class # ------------------------------------------------------------------- -# class PotentialSamplerSubClassTests(Test_PotentialSampler): - -# @classmethod -# def setup_class(cls): -# """Setup fixtures for testing.""" -# cls.potential = object() - -# # cls.inst = cls.obj(potential, package="GaussianMeasurementErrorSampler") - -# # /def - -# @classmethod -# def teardown_class(cls): -# """Teardown fixtures for testing.""" -# pass - -# # /def - -# ################################################################# -# # Method Tests - -# # /def ############################################################################## # END diff --git a/discO/extern/__init__.py b/discO/extern/__init__.py deleted file mode 100644 index eb12e752..00000000 --- a/discO/extern/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -# see LICENSE.rst - -"""**DOCSTRING**.""" - - -__all__ = [] - - -############################################################################## -# IMPORTS - -# PROJECT-SPECIFIC -from . import agama -from .agama import * - -__all__ += agama.__all__ - -############################################################################## -# END diff --git a/discO/extern/agama/__init__.py b/discO/extern/agama/__init__.py deleted file mode 100644 index c26eac23..00000000 --- a/discO/extern/agama/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- -# see LICENSE.rst - -"""AGAMA interface.""" - -__all__ = [] - - -############################################################################## -# IMPORTS - -# PROJECT-SPECIFIC -from . import fitter -from .fitter import * # noqa: F401, F403 - -__all__ += fitter.__all__ - - -############################################################################## -# END diff --git a/discO/extern/agama/fitter.py b/discO/extern/agama/fitter.py deleted file mode 100644 index e8ab9acf..00000000 --- a/discO/extern/agama/fitter.py +++ /dev/null @@ -1,68 +0,0 @@ -# -*- coding: utf-8 -*- - -"""**DOCSTRING**.""" - -__all__ = [ - "AGAMAPotentialFitter", -] - - -############################################################################## -# IMPORTS - -# THIRD PARTY -import astropy.coordinates as coord - -# PROJECT-SPECIFIC -from discO.common import CoordinateType, SkyCoordType -from discO.core.fitter import PotentialFitter - -############################################################################## -# CODE -############################################################################## - - -class AGAMAPotentialFitter(PotentialFitter, package="agama"): - """Fit a set of particles""" - - # FIXME! these are specific to multipole - def __init__( - self, - pot_type="Multipole", - symmetry="a", - gridsizeR=20, - lmax=2, - **kwargs, - ): - import agama - - self._fitter = agama.Potential - self._kwargs = { - "type": pot_type, - "symmetry": symmetry, - "gridsizeR": gridsizeR, - "lmax": lmax, - **kwargs, - } - - # /defs - - def __call__(self, c: CoordinateType) -> SkyCoordType: - """Fit Potential given particles.""" - - position = c.represent_as(coord.CartesianRepresentation).xyz.T - # TODO! velocities - mass = c.mass # TODO! what if don't have? - - particles = (position, mass) - - return self._fitter(particles=particles, **self._kwargs) - - # /def - - -# /class - - -############################################################################## -# END diff --git a/discO/plugin/agama/__init__.py b/discO/plugin/agama/__init__.py index 2ac5a81b..134b2e24 100644 --- a/discO/plugin/agama/__init__.py +++ b/discO/plugin/agama/__init__.py @@ -10,10 +10,12 @@ # IMPORTS # PROJECT-SPECIFIC -from . import sample +from . import fitter, sample +from .fitter import * # noqa: F401, F403 from .sample import * # noqa: F401, F403 __all__ += sample.__all__ +__all__ += fitter.__all__ ############################################################################## diff --git a/discO/plugin/agama/fitter.py b/discO/plugin/agama/fitter.py new file mode 100644 index 00000000..6b8f6979 --- /dev/null +++ b/discO/plugin/agama/fitter.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +"""Fit a potential to data with :mod:`~agama`.""" + +__all__ = [ + "AGAMAPotentialFitter", + "AGAMAMultipolePotentialFitter", +] + + +############################################################################## +# IMPORTS + +# BUILT-IN +import typing as T +import warnings +from types import MappingProxyType + +# THIRD PARTY +import agama +import astropy.coordinates as coord + +# PROJECT-SPECIFIC +from discO.common import CoordinateType, SkyCoordType +from discO.core.fitter import PotentialFitter + +############################################################################## +# PARAMETERS + +AGAMA_FITTER_REGISTRY = dict() # package : samplers + +############################################################################## +# CODE +############################################################################## + + +class AGAMAPotentialFitter(PotentialFitter, key="agama"): + """Fit a set of particles""" + + ####################################################### + # On the class + + _registry = AGAMA_FITTER_REGISTRY + + ################################################################# + # On the instance + + def __new__( + cls, + *, + pot_type: T.Optional[str] = None, + return_specific_class: bool = False, + **kwargs, + ): + self = super().__new__(cls, agama.Potential) + + # The class AGAMAPotentialFitter is a wrapper for anything in its + # registry If directly instantiating a AGAMAPotentialFitter (not + # subclass) we must also instantiate the appropriate subclass. Error + # if can't find. + if cls is AGAMAPotentialFitter: + + if pot_type not in cls._registry: + raise ValueError( + "PotentialFitter has no registered fitter for `pot_type`: " + f"{pot_type}", + ) + + # from registry. Registered in __init_subclass__ + instance = cls._registry[pot_type](**kwargs) + + if return_specific_class: + return instance + + self._instance = instance + + elif pot_type is not None: + raise ValueError( + "Can't specify 'pot_type' on PotentialFitter subclasses.", + ) + + elif return_specific_class is not False: + warnings.warn("Ignoring argument `return_specific_class`") + + return self + + # /def + + def __init__( + self, + pot_type: T.Optional[str] = None, + symmetry: str = "a", + **kwargs, + ): + if pot_type is None: + raise ValueError("must specify a `pot_type`") + + if self.__class__ is AGAMAPotentialFitter: + self._kwargs = MappingProxyType(self._instance._kwargs) + else: + self._kwargs = { + "type": pot_type, + "symmetry": symmetry, + **kwargs, + } + + # /defs + + ####################################################### + # Fitting + + def __call__(self, c: CoordinateType) -> SkyCoordType: + """Fit Potential given particles.""" + + position = c.represent_as(coord.CartesianRepresentation).xyz.T + # TODO! velocities + mass = c.mass # TODO! what if don't have? have as parameter? + + particles = (position, mass) + + return self._fitter(particles=particles, **self._kwargs) + + # /def + + +# /class + + +##################################################################### + + +class AGAMAMultipolePotentialFitter(AGAMAPotentialFitter, key="multipole"): + """Fit a set of particles with a Multipole expansion.""" + + def __init__( + self, + symmetry: str = "a", + gridsizeR: int = 20, + lmax: int = 2, + **kwargs, + ): + kwargs.pop("pot_type", None) # clear from kwargs + super().__init__( + pot_type="Multipole", + symmetry=symmetry, + gridsizeR=gridsizeR, + lmax=lmax, + **kwargs, + ) + + # /def + + +# /class + + +############################################################################## +# END diff --git a/discO/plugin/agama/sample.py b/discO/plugin/agama/sample.py index 39c80451..9ad0707e 100644 --- a/discO/plugin/agama/sample.py +++ b/discO/plugin/agama/sample.py @@ -28,7 +28,7 @@ ############################################################################## -class AGAMAPotentialSampler(PotentialSampler, package="agama"): +class AGAMAPotentialSampler(PotentialSampler, key="agama"): """Sample a :mod:`~agama` Potential. Parameters @@ -70,7 +70,7 @@ def __call__( if np.shape(pos)[1] == 6: pos, _ = pos[:, :3], pos[:, 3:] # TODO: vel else: - # vel = None # TODO + # vel = None # TODO! pass # TODO get agama units ! diff --git a/discO/plugin/agama/tests/__init__.py b/discO/plugin/agama/tests/__init__.py index 1a2b2d1e..db6343ca 100644 --- a/discO/plugin/agama/tests/__init__.py +++ b/discO/plugin/agama/tests/__init__.py @@ -5,6 +5,7 @@ __all__ = [ "sample_tests", + "fitter_tests", ] @@ -12,6 +13,7 @@ # IMPORTS # PROJECT-SPECIFIC +from . import test_fitter as fitter_tests from . import test_sample as sample_tests ############################################################################## diff --git a/discO/plugin/agama/tests/test_fitter.py b/discO/plugin/agama/tests/test_fitter.py new file mode 100644 index 00000000..feb659a6 --- /dev/null +++ b/discO/plugin/agama/tests/test_fitter.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- + +"""Testing :mod:`~discO.plugin.agama.fitter`.""" + +__all__ = [ + "Test_AGAMAPotentialFitter", + "Test_AGAMAMultipolePotentialFitter", +] + + +############################################################################## +# IMPORTS + +# THIRD PARTY +import agama +import pytest + +# PROJECT-SPECIFIC +from discO.core.tests.test_fitter import ( + Test_PotentialFitter as PotentialFitterTester, +) +from discO.plugin.agama import fitter + +############################################################################## +# TESTS +############################################################################## + + +class Test_AGAMAPotentialFitter( + PotentialFitterTester, + obj=fitter.AGAMAPotentialFitter, +): + @classmethod + def setup_class(cls): + """Setup fixtures for testing.""" + cls.potential = agama.Potential + + # register a unittest examples + class SubClassUnitTest(cls.obj, key="unittest"): + def __init__( + self, + symmetry="a", + **kwargs, + ): + kwargs.pop("pot_type", None) # clear from kwargs + super().__init__( + pot_type="Multipole", + symmetry="a", + gridsizeR=20, + lmax=2, + **kwargs, + ) + + # /defs + + cls.SubClassUnitTest = SubClassUnitTest + + # make instance. It depends. + if cls.obj is fitter.AGAMAPotentialFitter: + cls.inst = cls.obj(pot_type="unittest", symmetry="a") + + # /def + + ################################################################# + # Method Tests + + def test___new__(self): + """Test method ``__new__``. + + This is a wrapper class that acts differently when instantiating + a MeasurementErrorSampler than one of it's subclasses. + + """ + # there are no tests on super + # super().test___new__() + + # -------------------------- + if self.obj is fitter.AGAMAPotentialFitter: + + # -------------------------- + # for object not in registry + + with pytest.raises(ValueError) as e: + self.obj(pot_type=None) + + assert ( + "PotentialFitter has no registered fitter for `pot_type`: None" + ) in str(e.value) + + # --------------- + # with return_specific_class + + klass = self.obj._registry["unittest"] + + msamp = self.obj(pot_type="unittest", return_specific_class=True) + + # test class type + assert isinstance(msamp, klass) + assert isinstance(msamp, self.obj) + + # test inputs + assert msamp._fitter == self.potential + + # --------------- + # as wrapper class + + klass = self.obj._registry["unittest"] + + msamp = self.obj(pot_type="unittest", return_specific_class=False) + + # test class type + assert not isinstance(msamp, klass) + assert isinstance(msamp, self.obj) + assert isinstance(msamp._instance, klass) + + # test inputs + assert msamp._fitter == self.potential + + # -------------------------- + else: # never hit in Test_PotentialSampler, only in subs + + pot_type = tuple(self.obj._registry.keys())[0] + + # --------------- + # Can't have the "key" argument + + with pytest.raises(ValueError) as e: + self.obj(pot_type=pot_type, key="not None") + + assert "Can't specify 'pot_type'" in str(e.value) + + # --------------- + # warning + + with pytest.warns(UserWarning): + self.obj( + key=None, + return_specific_class=True, + ) + + # --------------- + # AOK + + msamp = self.obj() + + assert self.obj is not fitter.PotentialFitter + assert isinstance(msamp, self.obj) + assert isinstance(msamp, fitter.PotentialFitter) + assert not hasattr(msamp, "_instance") + assert msamp._fitter == self.potential + + # /def + + # ------------------------------- + + +# /class + + +# ------------------------------------------------------------------- + + +class Test_AGAMAMultipolePotentialFitter( + Test_AGAMAPotentialFitter, + obj=fitter.AGAMAMultipolePotentialFitter, +): + @classmethod + def setup_class(cls): + """Setup fixtures for testing.""" + super().setup_class() + cls.inst = cls.obj(symmetry="a") + + # /def + + +# /class + + +############################################################################## +# END diff --git a/discO/plugin/agama/tests/test_sample.py b/discO/plugin/agama/tests/test_sample.py index 06b1a934..2286dda8 100644 --- a/discO/plugin/agama/tests/test_sample.py +++ b/discO/plugin/agama/tests/test_sample.py @@ -15,7 +15,9 @@ import pytest # PROJECT-SPECIFIC -from discO.core.tests.test_sample import Test_PotentialSampler +from discO.core.tests.test_sample import ( + Test_PotentialSampler as PotentialSamplerTester, +) from discO.plugin.agama import sample ############################################################################## @@ -24,7 +26,7 @@ class Test_AGAMAPotentialSampler( - Test_PotentialSampler, + PotentialSamplerTester, obj=sample.AGAMAPotentialSampler, ): @classmethod diff --git a/discO/plugin/galpy/sample.py b/discO/plugin/galpy/sample.py index 23ef8ab3..15131a11 100644 --- a/discO/plugin/galpy/sample.py +++ b/discO/plugin/galpy/sample.py @@ -25,7 +25,7 @@ ############################################################################## -class GalpyPotentialSampler(PotentialSampler, package="galpy"): +class GalpyPotentialSampler(PotentialSampler, key="galpy"): """Sample a :mod:`~galpy` Potential. Parameters @@ -92,23 +92,6 @@ def __call__( # /def - # def sample_at_c(self, c=None, n=1, frame=None, **kargs): - # if c is None: - # R, z, phi = None, None, None - - # else: - # rep = c.represent_as(coord.CylindricalRepresentation) - # R, z, phi = rep.rho, rep.z, rep.phi - - # orbits = self._sampler.sample( - # R=R, z=z, phi=phi, n=n, return_orbit=True, - # ) - # samples = orbits.SkyCoord().transform_to(frame) - - # return samples - - # # /def - # /class diff --git a/discO/plugin/galpy/tests/.galpyrc b/discO/plugin/galpy/tests/.galpyrc new file mode 100644 index 00000000..3e41baf5 --- /dev/null +++ b/discO/plugin/galpy/tests/.galpyrc @@ -0,0 +1,13 @@ +[normalization] +ro = 8. +vo = 220. + +[astropy] +astropy-units = True +astropy-coords = True + +[plot] +seaborn-bovy-defaults = False + +[warnings] +verbose = False diff --git a/discO/plugin/galpy/tests/test_sample.py b/discO/plugin/galpy/tests/test_sample.py index c102942c..e78a69f6 100644 --- a/discO/plugin/galpy/tests/test_sample.py +++ b/discO/plugin/galpy/tests/test_sample.py @@ -41,6 +41,7 @@ def setup_class(cls): hernquist_pot = HernquistPotential(amp=cls.mass) hernquist_pot.turn_physical_on() # force units cls.potential = isotropicHernquistdf(hernquist_pot) + cls.potential.turn_physical_on() cls.inst = cls.obj(cls.potential) @@ -73,7 +74,16 @@ def test_call_parametrize(self, n, frame, kwargs): assert res.potential == self.potential assert len(res.mass) == n - assert np.isclose(res.mass.sum(), self.mass.to_value(u.solMass)) + + got = res.mass.sum() + if hasattr(got, "unit"): + got = got.to_value(u.solMass) + + expected = self.mass + if hasattr(expected, "unit"): + expected = expected.to_value(u.solMass) + + assert np.isclose(got, expected) # TODO! value tests when https://github.com/jobovy/galpy/pull/443 # assert np.allclose(res.ra.deg, [126.10132346, 214.92637031]) diff --git a/docs/examples/sampling.ipynb b/docs/examples/sampling.ipynb index c5e61926..2411903a 100644 --- a/docs/examples/sampling.ipynb +++ b/docs/examples/sampling.ipynb @@ -77,7 +77,16 @@ "from galpy import potential as gpot\n", "\n", "# PROJECT-SPECIFIC\n", - "from discO import GaussianMeasurementErrorSampler, PotentialSampler, conf" + "from discO import (\n", + " GaussianMeasurementErrorSampler,\n", + " PotentialFitter,\n", + " PotentialSampler,\n", + " conf,\n", + ")\n", + "from discO.plugin.agama.fitter import (\n", + " AGAMAMultipolePotentialFitter,\n", + " AGAMAPotentialFitter,\n", + ")" ] }, { @@ -124,9 +133,21 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'agama': discO.plugin.agama.sample.AGAMAPotentialSampler,\n", + " 'galpy': discO.plugin.galpy.sample.GalpyPotentialSampler}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "hernquist_pot = gpot.HernquistPotential(amp=mass, a=r0)" + "PotentialSampler._registry" ] }, { @@ -137,7 +158,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 6, @@ -146,6 +167,7 @@ } ], "source": [ + "hernquist_pot = gpot.HernquistPotential(amp=mass, a=r0)\n", "sampler = PotentialSampler(gdf.isotropicHernquistdf(hernquist_pot))\n", "sampler" ] @@ -187,19 +209,19 @@ "data": { "text/plain": [ "" + " [[(-3.11262307e+00, -5.0451043 , -98.6587091 ),\n", + " (-3.38050841e-01, 0.07845362, 267.57994542)],\n", + " [( 4.24479960e-04, 0.15739561, 269.0456028 ),\n", + " (-2.22841128e+00, -2.55825835, -286.27947762)],\n", + " [(-4.86802296e-01, -1.84616346, 131.7489185 ),\n", + " (-3.30005209e-01, -0.86782519, 178.30272477)]]>" ] }, "execution_count": 8, @@ -222,12 +244,12 @@ "text/plain": [ ", galcen_distance=8.122 kpc, galcen_v_sun=(12.9, 245.6, 7.78) km / s, z_sun=20.8 pc, roll=0.0 deg): (x, y, z) in kpc\n", - " [[(27.09722772, -7.99428837, 3.58566112),\n", - " (30.13989564, 15.66672266, -42.96944503)],\n", - " [( 7.81060648, -10.51244897, -32.21253614),\n", - " (27.04664526, -7.98764378, 3.64108865)],\n", - " [(30.28927575, 15.75782264, -42.94008277),\n", - " ( 7.90968944, -10.28750205, -32.3145178 )]]>" + " [[(27.2533559 , -7.99571288, 3.62182835),\n", + " (30.2343536 , 15.84875145, -43.02157771)],\n", + " [( 7.75435465, -10.41796935, -32.17488259),\n", + " (27.07607157, -7.78650818, 3.57517316)],\n", + " [(30.14487454, 15.79686217, -42.93728097),\n", + " ( 7.56685117, -10.24302387, -32.36044107)]]>" ] }, "execution_count": 9, @@ -243,13 +265,6 @@ "meas(samps)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -282,7 +297,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 11, @@ -343,27 +358,50 @@ "metadata": {}, "outputs": [ { - "ename": "AttributeError", - "evalue": "'PotentialSampler' object has no attribute 'resample'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0marray\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msampler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresample\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mniter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0marray\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'PotentialSampler' object has no attribute 'resample'" - ] + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "array = sampler.resample(niter=20, n=3)\n", - "array" + "array = sampler.sample(n=10, niter=20)\n", + "array[:1, :3]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "meas = GaussianMeasurementErrorSampler(c_err=0.1)\n", "\n", @@ -371,6 +409,72 @@ "meas(sample)" ] }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "PotentialFitter[(\"agama\", \"multipole\")] == PotentialFitter[\"agama\"][\n", + " \"multipole\"\n", + "] == AGAMAPotentialFitter[\"multipole\"]\n", + "\n", + "fitter = PotentialFitter[(\"agama\", \"multipole\")]()\n", + "fitter" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fit = AGAMAMultipolePotentialFitter(symmetry=\"a\")(sample)\n", + "fit" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(20,)" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "fitter.fit(array).shape" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -387,7 +491,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -406,7 +510,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [