Skip to content

Commit

Permalink
type-fixes + enforcements + contracts for analytical solutions (#58)
Browse files Browse the repository at this point in the history
* fix typo

* change contracts for analytical sols (return covariances for sol)

* fix type hints

* linting and style checks

* import sorting
  • Loading branch information
mathematicalmichael authored Jul 5, 2022
1 parent 19782bd commit 7732468
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 43 deletions.
50 changes: 44 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,60 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip install --upgrade pip
pip install --upgrade wheel setuptools setuptools_scm
- name: Inspect version info
run: |
python setup.py --version
git describe --dirty --tags --long --match "*[0-9]*"
- name: Test pip install syntax (without wheels)
- name: Test pip install to site-packages
run: |
pip install .
pip uninstall -y mud
- name: Install dependencies
- name: Test pip install local
run: |
pip install --upgrade pip
pip install --upgrade wheel setuptools setuptools_scm
- name: Test install with wheels
pip install -e .
pip uninstall -y mud
- name: Test build
run: |
python setup.py sdist bdist_wheel
pip uninstall -y mud
style:
name: Enforce style
strategy:
matrix:
python-version: ["3.10"]
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v2
with:
fetch-depth: 1

- name: setup
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}

- name: install
run: |
pip install --upgrade pip
pip install -e .[dev]
- name: linting
run: flake8 .

- name: imports
run: isort -c .

- name: typing
run: |
mypy src/mud/
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
# All configuration values have a default; values that are commented out
# serve to show the default.

import os
import sys
import inspect
import os
import shutil
import sys

# -- Path setup --------------------------------------------------------------

Expand Down
8 changes: 5 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ exclude =

[options.extras_require]
dev =
pytest
pytest-cov
black
coverage
coveralls
flake8
black
isort
mypy
pre-commit
pytest
pytest-cov

pub =
setuptools
Expand Down
6 changes: 4 additions & 2 deletions src/mud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

if sys.version_info[:2] >= (3, 8):
# TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
from importlib.metadata import PackageNotFoundError, version # pragma: no cover
from importlib.metadata import PackageNotFoundError # pragma: no cover
from importlib.metadata import version # pragma: no covern
else:
from importlib_metadata import PackageNotFoundError, version # pragma: no cover
from importlib_metadata import PackageNotFoundError # pragma: no cover
from importlib_metadata import version # pragma: no cover

try:
# Change here if project is renamed and does not equal the package name
Expand Down
2 changes: 0 additions & 2 deletions src/mud/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from typing import Callable, List, Optional, Union

import numpy as np

# from numpy.typing import ArrayLike
from matplotlib import pyplot as plt # type: ignore
from scipy.stats import distributions as dist # type: ignore
from scipy.stats import gaussian_kde as gkde # type: ignore
Expand Down
5 changes: 3 additions & 2 deletions src/mud/examples.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import pyplot as plt # type: ignore
from scipy.stats import distributions as ds # type: ignore

from mud.base import DensityProblem, IterativeLinearProblem
from mud.funs import wme
from mud.util import std_from_equipment
from scipy.stats import distributions as ds


def rotation_map(qnum=10, tol=0.1, b=None, ref_param=None, seed=None):
Expand Down
38 changes: 22 additions & 16 deletions src/mud/funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import sys

import numpy as np
from scipy.stats import distributions as dists # type: ignore

from mud import __version__
from mud.base import BayesProblem, DensityProblem
from scipy.stats import distributions as dists

__author__ = "Mathematical Michael"
__copyright__ = "Mathematical Michael"
Expand Down Expand Up @@ -169,7 +170,7 @@ def check_args(A, b, y, mean, cov, data_cov):
return ravel, z, mean, cov, data_cov


def mud_sol(A, b, y=None, mean=None, cov=None, data_cov=None, return_pred=False):
def mud_sol(A, b, y=None, mean=None, cov=None, data_cov=None):
"""
For SWE problem, we are inverting N(0,1).
This is the default value for `data_cov`.
Expand All @@ -183,10 +184,7 @@ def mud_sol(A, b, y=None, mean=None, cov=None, data_cov=None, return_pred=False)
# When y was passed as a 1d-array, we flatten the coefficients.
mud_point = mud_point.ravel()

if return_pred:
return mud_point, update
else:
return mud_point
return mud_point


def updated_cov(X, init_cov=None, data_cov=None):
Expand Down Expand Up @@ -239,7 +237,7 @@ def updated_cov(X, init_cov=None, data_cov=None):
return up_cov


def mud_sol_alt(A, b, y=None, mean=None, cov=None, data_cov=None, return_pred=False):
def mud_sol_with_cov(A, b, y=None, mean=None, cov=None, data_cov=None):
"""
Doesn't use R directly, uses new equations.
This presents the equation as a rank-k update
Expand All @@ -254,13 +252,10 @@ def mud_sol_alt(A, b, y=None, mean=None, cov=None, data_cov=None, return_pred=Fa
# When y was passed as a 1d-array, we flatten the coefficients.
mud_point = mud_point.ravel()

if return_pred:
return mud_point, update
else:
return mud_point
return mud_point, up_cov


def map_sol(A, b, y=None, mean=None, cov=None, data_cov=None, w=1, return_pred=False):
def map_sol(A, b, y=None, mean=None, cov=None, data_cov=None, w=1):
ravel, z, mean, cov, data_cov = check_args(A, b, y, mean, cov, data_cov)
inv = np.linalg.inv
post_cov = inv(A.T @ inv(data_cov) @ A + w * inv(cov))
Expand All @@ -271,10 +266,21 @@ def map_sol(A, b, y=None, mean=None, cov=None, data_cov=None, w=1, return_pred=F
# When y was passed as a 1d-array, we flatten the coefficients.
map_point = map_point.ravel()

if return_pred:
return map_point, update
else:
return map_point
return map_point


def map_sol_with_cov(A, b, y=None, mean=None, cov=None, data_cov=None, w=1):
ravel, z, mean, cov, data_cov = check_args(A, b, y, mean, cov, data_cov)
inv = np.linalg.inv
post_cov = inv(A.T @ inv(data_cov) @ A + w * inv(cov))
update = post_cov @ A.T @ inv(data_cov)
map_point = mean + update @ z

if ravel:
# When y was passed as a 1d-array, we flatten the coefficients.
map_point = map_point.ravel()

return map_point, post_cov


def performEpoch(A, b, y, initial_mean, initial_cov, data_cov=None, idx=None):
Expand Down
2 changes: 1 addition & 1 deletion src/mud/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import pyplot as plt # type: ignore

from mud.util import null_space

Expand Down
28 changes: 21 additions & 7 deletions src/mud/util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np
from numpy.typing import ArrayLike
from scipy.special import erfinv
from scipy.special import erfinv # type: ignore


def std_from_equipment(tolerance=0.1, probability=0.95):
Expand All @@ -16,7 +15,11 @@ def std_from_equipment(tolerance=0.1, probability=0.95):
return standard_deviation


def transform_linear_map(operator, data, std):
def transform_linear_map(
operator: np.ndarray,
data: Union[np.ndarray, List[float], Tuple[float]],
std: Union[np.ndarray, float, List[float], Tuple[float]],
):
"""
Takes a linear map `operator` of size (len(data), dim_input)
or (1, dim_input) for repeated observations, along with
Expand Down Expand Up @@ -72,7 +75,18 @@ def transform_linear_map(operator, data, std):
return A, b


def transform_linear_setup(operator_list, data_list, std_list):
def transform_linear_setup(
operator_list: List[np.ndarray],
data_list: Union[List[np.ndarray], Tuple[np.ndarray]],
std_list: Union[
float,
np.ndarray,
List[float],
Tuple[float],
Tuple[Tuple[float]],
List[List[float]],
],
):
if isinstance(std_list, (float, int)):
std_list = [std_list] * len(data_list)
# repeat process for multiple quantities of interest
Expand All @@ -85,7 +99,7 @@ def transform_linear_setup(operator_list, data_list, std_list):
return np.vstack(operators), np.vstack(datas)


def null_space(A, rcond=None):
def null_space(A: np.ndarray, rcond: Optional[float] = None):
"""
Construct an orthonormal basis for the null space of A using SVD
Expand Down Expand Up @@ -209,6 +223,6 @@ def make_2d_normal_mesh(N: int = 50, window: int = 1):
return (X, Y, XX)


def set_shape(array: ArrayLike, shape: Union[List, Tuple] = (1, -1)):
def set_shape(array: np.ndarray, shape: Union[List, Tuple] = (1, -1)) -> np.ndarray:
"""Resizes inputs if they are one-dimensional."""
return array.reshape(shape) if array.ndim < 2 else array
7 changes: 5 additions & 2 deletions tests/test_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ def test_solutions_with_orthogonal_map(self):
# Act
y = A @ t + b
sol_mud = mdf.mud_sol(A, b, y, cov=c)
sol_alt = mdf.mud_sol_alt(A, b, y, cov=c)
sol_mud_alt, updated_cov = mdf.mud_sol_with_cov(A, b, y, cov=c)
sol_map = mdf.map_sol(A, b, y, cov=c)
sol_map_alt, posterior_cov = mdf.map_sol_with_cov(A, b, y, cov=c)

err_mud = sol_mud - t
err_alt = sol_alt - t
err_alt = sol_mud_alt - t
err_map = sol_map - t

# Assert
assert np.linalg.norm(sol_map - sol_map_alt) < 1e-12
assert np.linalg.norm(sol_mud - sol_mud_alt) < 1e-6
assert np.linalg.norm(err_mud) < 1e-6
assert np.linalg.norm(err_alt) < 1e-6
assert np.linalg.norm(err_mud) < np.linalg.norm(err_map)
Expand Down

0 comments on commit 7732468

Please sign in to comment.