-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Arnaud De-Mattia
committed
Nov 11, 2022
0 parents
commit 5eb2af5
Showing
45 changed files
with
6,479 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
**/dust | ||
doc/_build/ | ||
**/__pycache__ | ||
**/pytest_cache | ||
**/*.egg | ||
**/*.egg-info | ||
**/.astropy | ||
**/.cache | ||
**/.config | ||
**/.wget* | ||
**/dist | ||
**/eggs | ||
**/build | ||
**/.ipynb_checkpoints | ||
**/bak/ | ||
**/_tests | ||
**/_catalog | ||
**/_plots | ||
**/_results | ||
**/tests/_* | ||
**/tests/TestSimpleLikelihood* | ||
|
||
# Unit test / coverage reports | ||
**/.coverage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
BSD 3-Clause License | ||
|
||
Copyright (c) 2021, cosmodesi | ||
All rights reserved. | ||
|
||
Redistribution and use in source and binary forms, with or without | ||
modification, are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright notice, this | ||
list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright notice, | ||
this list of conditions and the following disclaimer in the documentation | ||
and/or other materials provided with the distribution. | ||
|
||
3. Neither the name of the copyright holder nor the names of its | ||
contributors may be used to endorse or promote products derived from | ||
this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | ||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | ||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | ||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | ||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | ||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | ||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# desilike | ||
|
||
WARNING: this is ongoing work! | ||
|
||
**desilike** is an attempt to provide a common framework for writing DESI likelihoods, | ||
that can be imported in common cosmological inference codes (Cobaya, CosmoSIS, MontePython). | ||
|
||
Example notebooks presenting most use cases are provided in directory nb/. | ||
|
||
## Documentation | ||
|
||
Documentation is hosted on Read the Docs, [desilike docs](https://desilike.readthedocs.io/). | ||
|
||
## Requirements | ||
|
||
Only strict requirements are: | ||
|
||
- numpy | ||
- scipy | ||
|
||
## Installation | ||
|
||
### pip | ||
|
||
Simply run: | ||
``` | ||
python -m pip install git+https://github.com/adematti/desilike | ||
``` | ||
|
||
### git | ||
|
||
First: | ||
``` | ||
git clone https://github.com/adematti/desilike.git | ||
``` | ||
To install the code: | ||
``` | ||
python setup.py install --user | ||
``` | ||
Or in development mode (any change to Python code will take place immediately): | ||
``` | ||
python setup.py develop --user | ||
``` | ||
|
||
## License | ||
|
||
**desilike** is free software distributed under a BSD3 license. For details see the [LICENSE](https://github.com/adematti/desilike/blob/main/LICENSE). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._version import __version__ | ||
from .utils import setup_logging | ||
from .parameter import Parameter, ParameterPrior, ParameterCollection, ParameterArray, ParameterValues |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = '1.0.0' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,277 @@ | ||
import os | ||
import sys | ||
|
||
import numpy as np | ||
|
||
from .utils import BaseClass, NamespaceDict, Monitor | ||
from .io import BaseConfig | ||
from .parameter import ParameterCollection, ParameterArray, ParameterValues | ||
|
||
|
||
namespace_delimiter = '.' | ||
|
||
|
||
class PipelineError(Exception): | ||
|
||
"""Exception raised when issue with pipeline.""" | ||
|
||
|
||
class Info(NamespaceDict): | ||
|
||
"""Namespace/dictionary holding calculator static attributes.""" | ||
|
||
|
||
class BasePipeline(BaseClass): | ||
|
||
def __init__(self, calculator): | ||
self.calculators = [calculator] | ||
|
||
def callback(calculator): | ||
for require in calculator.runtime_info.requires.values(): | ||
if require not in self.calculators: | ||
self.calculators.append(require) | ||
callback(require) | ||
|
||
callback(self.calculators[0]) | ||
self.calculators = self.calculators[::-1] | ||
self._derived = None | ||
|
||
def _get_params(self, params=None, quiet=False): | ||
params_from_calculator = {} | ||
ref_params = ParameterCollection(params) | ||
params = ParameterCollection() | ||
for calculator in self.calculators: | ||
for iparam, param in enumerate(calculator.runtime_info.full_params): | ||
if param in ref_params: | ||
calculator.runtime_info.full_params[iparam] = param = ref_params[param] | ||
if not quiet and param in params: | ||
if param.derived and param.fixed: | ||
msg = 'Derived parameter {} of {} is already derived in {}.'.format(param, calculator, params_from_calculator[param.name]) | ||
if param.basename not in calculator.runtime_info.derived_auto and param.basename not in params_from_calculator[param.name].runtime_info.derived_auto: | ||
raise PipelineError(msg) | ||
elif self.mpicomm.rank == 0: | ||
self.log_warning(msg) | ||
elif param != params[param]: | ||
raise PipelineError('Parameter {} of {} is different from that of {}.'.format(param, calculator, params_from_calculator[param.name])) | ||
params_from_calculator[param.name] = calculator | ||
params.set(param) | ||
for param in ref_params: | ||
if param not in params: | ||
raise PipelineError('Parameter {} is not used by any calculator'.format(param)) | ||
self._derived = None | ||
return params | ||
|
||
@property | ||
def params(self): | ||
return self._get_params() | ||
|
||
@property | ||
def param_values(self): | ||
if getattr(self, '_param_values', None) is None: | ||
self._param_values = {param.name: param.value for param in self.params} | ||
return self._param_values | ||
|
||
def eval_params(self, params): | ||
toret = {} | ||
all_params = {**self.param_values, **params} | ||
for param in all_params: | ||
try: | ||
toret[param] = self.params[param].eval(**all_params) | ||
except KeyError: | ||
pass | ||
return toret | ||
|
||
def calculate(self, **params): | ||
to_calculate = self._derived is None | ||
self.param_values.update(params) | ||
params = self.eval_params(params) | ||
for calculator in self.calculators: # start by first calculator, and by the last one | ||
for param in calculator.runtime_info.full_params: | ||
value = params.get(param.name, None) | ||
if value is not None and param.basename in calculator.runtime_info.param_values and value != calculator.runtime_info.param_values[param.basename]: | ||
calculator.runtime_info.param_values[param.basename] = value | ||
to_calculate = True | ||
result = None | ||
if to_calculate: | ||
self.derived = ParameterValues() | ||
for calculator in self.calculators: | ||
result = calculator.runtime_info.calculate() | ||
for param in calculator.runtime_info.full_params: | ||
if param.depends: | ||
self.derived.set(ParameterArray(np.asarray(params[param.name]), param=param)) | ||
self.derived.update(calculator.runtime_info.derived) | ||
return result | ||
|
||
|
||
class RuntimeInfo(BaseClass): | ||
|
||
"""Information about calculator name, requirements, parameters values at a given step, etc.""" | ||
|
||
def __init__(self, calculator): | ||
""" | ||
initialize :class:`RuntimeInfo`. | ||
Parameters | ||
---------- | ||
calculator : BaseCalculator | ||
The calculator this :class:`RuntimeInfo` instance is attached to. | ||
""" | ||
self.calculator = calculator | ||
self.monitor = Monitor() | ||
self.required_by = set() | ||
|
||
@property | ||
def requires(self): | ||
if getattr(self, '_requires', None) is None: | ||
self._requires = {} | ||
for name, value in self.calculator.__dict__.items(): | ||
if isinstance(value, BaseCalculator): | ||
self._requires[name] = value | ||
self.requires = self._requires | ||
return self._requires | ||
|
||
@requires.setter | ||
def requires(self, requires): | ||
self._requires = dict(requires) | ||
for name, require in self._requires.items(): | ||
require.runtime_info.required_by.add((self.calculator, name)) | ||
self._pipeline = None | ||
|
||
@property | ||
def pipeline(self): | ||
if getattr(self, '_pipeline', None) is None: | ||
self._pipeline = BasePipeline(self.calculator) | ||
return self._pipeline | ||
|
||
@property | ||
def full_params(self): | ||
if getattr(self, '_full_params', None) is None: | ||
self._full_params = self.calculator.params | ||
return self._full_params | ||
|
||
@full_params.setter | ||
def full_params(self, full_params): | ||
self._full_params = full_params | ||
self._base_params = self._solved_params = self._derived_params = self._param_values = self._pipeline = None | ||
|
||
@property | ||
def base_params(self): | ||
if getattr(self, '_base_params', None) is None: | ||
self._base_params = {param.basename: param for param in self.full_params} | ||
return self._base_params | ||
|
||
@property | ||
def solved_params(self): | ||
if getattr(self, '_solved_params', None) is None: | ||
self._solved_params = self.full_params.select(solved=True) | ||
return self._solved_params | ||
|
||
@property | ||
def derived_params(self): | ||
if getattr(self, '_derived_params', None) is None: | ||
self._derived_params = self.full_params.select(derived=True, solved=False, depends={}) | ||
return self._derived_params | ||
|
||
@property | ||
def derived(self): | ||
if getattr(self, '_derived', None) is None: | ||
self._derived = ParameterValues() | ||
if self.derived_params: | ||
state = self.calculator.__getstate__() | ||
for param in self.derived_params: | ||
name = param.basename | ||
if name in state: value = state[name] | ||
else: value = getattr(self.calculator, name) | ||
self._derived.set(ParameterArray(np.asarray(value), param=param), output=True) | ||
return self._derived | ||
|
||
def calculate(self, **params): | ||
self.param_values.update(**params) | ||
self.monitor.start() | ||
try: | ||
self.result = self.calculator.calculate(**self.param_values) | ||
except Exception as exc: | ||
raise PipelineError('Error in method calculate of {}'.format(self.calculator)) from exc | ||
self.monitor.stop() | ||
return self.result | ||
|
||
@property | ||
def param_values(self): | ||
if getattr(self, '_param_values', None) is None: | ||
self._param_values = {param.basename: param.value for param in self.full_params if (not param.drop) and (param.depends or (not param.derived) or param.solved)} | ||
return self._param_values | ||
|
||
def __getstate__(self): | ||
"""Return this class state dictionary.""" | ||
return self.__dict__.copy() | ||
|
||
def update(self, *args, **kwargs): | ||
"""Update with provided :class:`RuntimeInfo` instance of dict.""" | ||
state = self.__getstate__() | ||
if len(args) == 1 and isinstance(args[0], self.__class__): | ||
state.update(args[0].__getstate__()) | ||
elif len(args): | ||
raise ValueError('Unrecognized arguments {}'.format(args)) | ||
state.update(kwargs) | ||
for name, value in state.items(): | ||
setattr(self, name, value) # this is to properly update properties with setters | ||
|
||
def clone(self, *args, **kwargs): | ||
"""Clone, i.e. copy and update.""" | ||
new = self.copy() | ||
new.update(*args, **kwargs) | ||
return new | ||
|
||
def deepcopy(self): | ||
import copy | ||
new = self.copy() | ||
new.full_params = copy.deepcopy(self.full_params) | ||
return new | ||
|
||
|
||
class BaseCalculator(BaseClass): | ||
|
||
def __getattr__(self, name): | ||
if name == 'runtime_info': | ||
self.initialize() | ||
return self.runtime_info | ||
return super(BaseCalculator, self).__getattribute__(name) | ||
|
||
def __setattr__(self, name, value): | ||
super(BaseCalculator, self).__setattr__(name, value) | ||
if 'runtime_info' in self.__dict__ and name in self.runtime_info.requires: | ||
self.runtime_info.requires[name] = value | ||
self.runtime_info.requires = self.runtime_info.requires | ||
for calculator, name in self.runtime_info.required_by: | ||
setattr(calculator.runtime_info, name, self) | ||
|
||
def __new__(cls, *args, **kwargs): | ||
from functools import wraps | ||
|
||
def initialize(func): | ||
@wraps(func) | ||
def wrapper(self, *args, **kwargs): | ||
func(self, *args, **kwargs) | ||
self.runtime_info = RuntimeInfo(self) | ||
return wrapper | ||
|
||
cls.initialize = initialize(cls.initialize) | ||
|
||
cls.info = Info(**getattr(cls, 'info', {})) | ||
cls.params = ParameterCollection() | ||
if hasattr(cls, 'config_fn'): | ||
dirname = os.path.dirname(sys.modules[cls.__module__].__file__) | ||
config = BaseConfig(os.path.join(dirname, cls.config_fn), index={'class': cls.__name__}) | ||
cls.info = Info(**{**config.get('info', {}), **cls.info}) | ||
params = ParameterCollection(config.get('params', {})) | ||
params.update(cls.params) | ||
cls.params = params | ||
init = config.get('init', {}) | ||
if init: kwargs = {**init, **kwargs} | ||
return super(BaseCalculator, cls).__new__(cls) | ||
|
||
def __call__(self, **params): | ||
return self.runtime_info.pipeline.calculate(**params) | ||
|
||
def __getstate__(self): | ||
return {} |
Oops, something went wrong.