Skip to content
54 changes: 35 additions & 19 deletions colibri/analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
import jax.lax.linalg as jlinalg
import numpy as np
import scipy.special as special

Expand Down Expand Up @@ -83,7 +84,7 @@ def analytic_evidence_uniform_prior(sol_covmat, sol_mean, max_logl, a_vec, b_vec

@check_pdf_model_is_linear
def analytic_fit(
central_inv_covmat_index,
central_covmat_index,
_pred_data,
pdf_model,
analytic_settings,
Expand All @@ -102,8 +103,8 @@ def analytic_fit(

Parameters
----------
central_inv_covmat_index: commondata_utils.CentralInvCovmatIndex
dataclass containing central values and inverse covmat.
central_covmat_index: commondata_utils.CentralCovmatIndex
dataclass containing central values and covariance matrix.

_pred_data: @jax.jit CompiledFunction
Prediction function for the fit.
Expand Down Expand Up @@ -141,23 +142,37 @@ def analytic_fit(
intercept = pred_and_pdf(jnp.zeros(len(parameters)), fast_kernel_arrays)[0]

# Construct the analytic solution
central_values = central_inv_covmat_index.central_values
inv_covmat = central_inv_covmat_index.inv_covmat
central_values = central_covmat_index.central_values
covmat = central_covmat_index.covmat

# Solve chi2 analytically for the mean
Y = central_values - intercept
Sigma = inv_covmat
X = predictions.T - intercept[:, None]

# * Check that covmat is positive definite
if jnp.any(jla.eigh(X.T @ Sigma @ X)[0] <= 0.0):
t0 = time.time()

# Cholesky factorization: S = L L^T
# upper False means that we want the lower triangular matrix L
L = jla.cholesky(covmat, upper=False)

# Whiten the problem: Y' = L^-1 Y, X' = L^-1 X
Y_tilde = jlinalg.triangular_solve(L, Y, left_side=True, lower=True)
X_tilde = jlinalg.triangular_solve(L, X, left_side=True, lower=True)

if jnp.any(jla.eigh(X_tilde.T @ X_tilde)[0] <= 0.0):
raise ValueError(
"The obtained covariance matrix for the analytic solution is not positive definite."
)

t0 = time.time()
sol_mean = jla.inv(X.T @ Sigma @ X) @ X.T @ Sigma @ Y
sol_covmat = jla.inv(X.T @ Sigma @ X)
# Compute QR decomposition of X_tilde for numerical stability in the inversion
Q, R = jla.qr(X_tilde)

# NOTE: R is upper triangular in QR decomposition, so we need to set lower=False
sol_mean = jlinalg.triangular_solve(R, Q.T @ Y_tilde, left_side=True, lower=False)

I_R = jnp.eye(R.shape[0])
R_inv = jlinalg.triangular_solve(R, I_R, left_side=True, lower=False)
sol_covmat = R_inv @ R_inv.T

key = jax.random.PRNGKey(analytic_settings["sampling_seed"])

Expand All @@ -168,8 +183,6 @@ def analytic_fit(
sol_covmat,
shape=(analytic_settings["full_sample_size"],),
)
t1 = time.time()
log.info("ANALYTIC SAMPLING RUNTIME: %f s" % (t1 - t0))

# Compute the evidence
# This is the log of the evidence, which is the log of the integral of the likelihood
Expand Down Expand Up @@ -222,8 +235,8 @@ def analytic_fit(

gaussian_integral = jnp.log(jnp.sqrt(jla.det(2 * jnp.pi * sol_covmat)))
log_prior = jnp.log(1 / prior_width).sum()
# Compute maximum log likelihood
min_chi2 = (Y - X @ sol_mean) @ Sigma @ (Y - X @ sol_mean)
# Compute maximum log likelihood in the whitened basis
min_chi2 = (Y_tilde - X_tilde @ sol_mean).T @ (Y_tilde - X_tilde @ sol_mean)
# Compute the log likelihood
max_logl = -0.5 * min_chi2

Expand All @@ -244,12 +257,12 @@ def analytic_fit(
min_chi2 = -2 * max_logl
log.info(f"Minimum chi2 = {min_chi2}")

BIC = min_chi2 + sol_covmat.shape[0] * np.log(Sigma.shape[0])
BIC = min_chi2 + sol_covmat.shape[0] * np.log(covmat.shape[0])
AIC = min_chi2 + 2 * sol_covmat.shape[0]

# Compute average chi2
diffs = Y[:, None] - X @ full_samples.T
avg_chi2 = jnp.einsum("ij,jk,ki->i", diffs.T, Sigma, diffs).mean()
# Compute average chi2 (in whitened basis)
diffs = Y_tilde[:, None] - X_tilde @ full_samples.T
avg_chi2 = jnp.mean(jnp.sum(diffs**2, axis=0))

log.info(f"Average chi2 = {avg_chi2}")

Expand All @@ -260,6 +273,9 @@ def analytic_fit(
# Resample the posterior for PDF set
samples = full_samples[: analytic_settings["n_posterior_samples"]]

t1 = time.time()
log.info("ANALYTIC SAMPLING RUNTIME: %f s" % (t1 - t0))

return AnalyticFit(
analytic_specs=analytic_settings,
resampled_posterior=samples,
Expand Down
17 changes: 1 addition & 16 deletions colibri/commondata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jla

from colibri.theory_predictions import make_pred_dataset
from colibri.core import CentralCovmatIndex, CentralInvCovmatIndex
from colibri.core import CentralCovmatIndex


def experimental_commondata_tuple(data):
Expand Down Expand Up @@ -195,17 +194,3 @@ def pseudodata_central_covmat_index(
covariance matrix for a Monte Carlo fit.
"""
return central_covmat_index(commondata_tuple, data_generation_covariance_matrix)


def central_inv_covmat_index(central_covmat_index):
    """
    Build a CentralInvCovmatIndex out of a CentralCovmatIndex.

    The central values and their indices are copied over as they are,
    while the covariance matrix is replaced by its inverse.

    Parameters
    ----------
    central_covmat_index: CentralCovmatIndex
        object exposing ``central_values``, ``central_values_idx``
        and ``covmat``.

    Returns
    -------
    CentralInvCovmatIndex
        dataclass holding the central values, their indices and the
        inverse of the input covariance matrix.
    """
    source = central_covmat_index
    return CentralInvCovmatIndex(
        central_values=source.central_values,
        central_values_idx=source.central_values_idx,
        # explicit matrix inverse of the covariance matrix
        inv_covmat=jla.inv(source.covmat),
    )
10 changes: 0 additions & 10 deletions colibri/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,6 @@ def to_dict(self):
return asdict(self)


@dataclass(frozen=True)
class CentralInvCovmatIndex:
    """
    Immutable record grouping central values, an inverse covariance
    matrix and the indices of the central values.
    """

    # NOTE: the field order defines the positional constructor signature.
    central_values: jnp.array
    inv_covmat: jnp.array
    central_values_idx: jnp.array

    def to_dict(self):
        """
        Return the fields of this instance as a plain dictionary.
        """
        as_mapping = asdict(self)
        return as_mapping


@dataclass(frozen=True)
class BayesianPrior:
prior_transform: Callable
Expand Down
7 changes: 1 addition & 6 deletions colibri/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,12 @@ def wmin_param(params):
"""


# Stand-in object exposing central values, an inverse covariance matrix and
# central-value indices; every attribute is sized by TEST_N_DATA.
MOCK_CENTRAL_INV_COVMAT_INDEX = Mock()
MOCK_CENTRAL_INV_COVMAT_INDEX.central_values = jnp.ones(TEST_N_DATA)
# identity matrix used as the inverse covariance matrix
MOCK_CENTRAL_INV_COVMAT_INDEX.inv_covmat = jnp.eye(TEST_N_DATA)
MOCK_CENTRAL_INV_COVMAT_INDEX.central_values_idx = jnp.arange(TEST_N_DATA)

# Stand-in object with the same attributes, except that it carries the
# covariance matrix itself (``covmat``) rather than its inverse.
MOCK_CENTRAL_COVMAT_INDEX = Mock()
MOCK_CENTRAL_COVMAT_INDEX.central_values = jnp.ones(TEST_N_DATA)
# identity matrix used as the covariance matrix
MOCK_CENTRAL_COVMAT_INDEX.covmat = jnp.eye(TEST_N_DATA)
MOCK_CENTRAL_COVMAT_INDEX.central_values_idx = jnp.arange(TEST_N_DATA)
"""
Mock instance of Central Inverse covmat index object.
Mock instance of Central covmat index object.
"""


Expand Down
12 changes: 6 additions & 6 deletions colibri/tests/test_analytic_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from colibri.analytic_fit import AnalyticFit, analytic_fit, run_analytic_fit
from colibri.core import PriorSettings
from colibri.tests.conftest import (
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
MOCK_PDF_MODEL,
TEST_FK_ARRAYS,
TEST_FORWARD_MAP_DIS,
Expand Down Expand Up @@ -46,7 +46,7 @@ def test_analytic_fit_flat_direction():
with pytest.raises(ValueError):
# Run the analytic fit and make sure that the Value Error is raised
analytic_fit(
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
_pred_data,
MOCK_PDF_MODEL,
analytic_settings,
Expand All @@ -69,7 +69,7 @@ def test_analytic_fit(caplog):

# Run the analytic fit
result = analytic_fit(
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
_pred_data,
MOCK_PDF_MODEL,
analytic_settings,
Expand All @@ -91,7 +91,7 @@ def test_analytic_fit(caplog):
# Run the analytic fit
with caplog.at_level(logging.ERROR): # Set the log level to ERROR
result_2 = analytic_fit(
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
_pred_data,
MOCK_PDF_MODEL,
analytic_settings,
Expand Down Expand Up @@ -129,7 +129,7 @@ def test_analytic_fit_different_priors(caplog):

# Run the analytic fit
result = analytic_fit(
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
_pred_data,
MOCK_PDF_MODEL,
analytic_settings,
Expand All @@ -155,7 +155,7 @@ def test_analytic_fit_different_priors(caplog):

# Run the analytic fit with custom uniform prior
result = analytic_fit(
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
_pred_data,
MOCK_PDF_MODEL,
analytic_settings,
Expand Down
21 changes: 0 additions & 21 deletions colibri/tests/test_commondata_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
"""

import jax.numpy as jnp
import jax.scipy.linalg as jla
import pandas as pd
from nnpdf_data.coredata import CommonData
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -115,23 +114,3 @@ def test_level1_commondata_tuple():
reference_level1_central_values["data"].values,
current_level1_central_values[0].central_values,
)


def test_central_inv_covmat_index():
    """
    Check that central_inv_covmat_index is consistent with the
    central_covmat_index it is derived from.
    """
    # both providers are called with the same merged keyword arguments
    api_kwargs = {**TEST_DATASETS, **T0_PDFSET}
    covmat_index = colibriAPI.central_covmat_index(**api_kwargs)
    inv_covmat_index = colibriAPI.central_inv_covmat_index(**api_kwargs)

    # the stored inverse must agree with a direct inversion of the covmat
    expected_inverse = jla.inv(covmat_index.covmat)
    assert_allclose(inv_covmat_index.inv_covmat, expected_inverse)

    # central values and their indices must be carried over unchanged
    for attribute in ("central_values", "central_values_idx"):
        assert_allclose(
            getattr(inv_covmat_index, attribute),
            getattr(covmat_index, attribute),
        )

    # to_dict must hand back a plain dictionary
    as_mapping = inv_covmat_index.to_dict()
    assert isinstance(as_mapping, dict)
6 changes: 3 additions & 3 deletions colibri/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from colibri.api import API as cAPI
from colibri.tests.conftest import (
MOCK_CENTRAL_INV_COVMAT_INDEX,
MOCK_CENTRAL_COVMAT_INDEX,
MOCK_PDF_MODEL,
TEST_DATASET,
TEST_DATASET_HAD,
Expand Down Expand Up @@ -339,7 +339,7 @@ def test_likelihood_float_type(
):

_pred_data = lambda x, fks: jnp.ones(
len(MOCK_CENTRAL_INV_COVMAT_INDEX.central_values)
len(MOCK_CENTRAL_COVMAT_INDEX.central_values)
) # Mock _pred_data
FIT_XGRID = jnp.linspace(0, 1, 10) # Mock FIT_XGRID
output_path = tmp_path
Expand All @@ -355,7 +355,7 @@ def test_likelihood_float_type(
FIT_XGRID=FIT_XGRID,
bayesian_prior=mock_bayesian_prior,
output_path=output_path,
central_inv_covmat_index=MOCK_CENTRAL_INV_COVMAT_INDEX,
central_covmat_index=MOCK_CENTRAL_COVMAT_INDEX,
fast_kernel_arrays=fast_kernel_arrays,
)

Expand Down
8 changes: 4 additions & 4 deletions colibri/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def likelihood_float_type(
FIT_XGRID,
bayesian_prior,
output_path,
central_inv_covmat_index,
central_covmat_index,
fast_kernel_arrays,
):
"""
Expand All @@ -306,8 +306,8 @@ def likelihood_float_type(

loss_function = chi2

central_values = central_inv_covmat_index.central_values
inv_covmat = central_inv_covmat_index.inv_covmat
central_values = central_covmat_index.central_values
covmat = central_covmat_index.covmat

pred_and_pdf = pdf_model.pred_and_pdf_func(FIT_XGRID, forward_map=_pred_data)

Expand All @@ -319,7 +319,7 @@ def log_likelihood(params, central_values, inv_covmat, fast_kernel_arrays):
jax.random.uniform(jax.random.PRNGKey(0), shape=(len(pdf_model.param_names),))
)

dtype = log_likelihood(params, central_values, inv_covmat, fast_kernel_arrays).dtype
dtype = log_likelihood(params, central_values, covmat, fast_kernel_arrays).dtype

# save the dtype to the output path
with open(output_path / "dtype.txt", "w") as file:
Expand Down
Loading