diff --git a/flip/__init__.py b/flip/__init__.py
index 297ba15..686fb34 100644
--- a/flip/__init__.py
+++ b/flip/__init__.py
@@ -2,7 +2,17 @@
 import os
 
+from flip.utils import create_log
+
+log = create_log()
 
 from . import covariance, fisher, fitter, gridding, likelihood, power_spectra, utils
 
+try:
+    import jax
+
+    jax.config.update("jax_enable_x64", True)
+except ImportError:
+    log.add("JAX is not available, falling back to numpy and scipy")
+
 __version__ = "1.0.0"
 __flip_dir_path__ = os.path.dirname(__file__)
diff --git a/flip/covariance/__init__.py b/flip/covariance/__init__.py
index 638da73..7ac1b86 100644
--- a/flip/covariance/__init__.py
+++ b/flip/covariance/__init__.py
@@ -1,4 +1,5 @@
 """Init file of the flip.covariance package."""
 
 from . import adamsblake17plane, adamsblake20, carreres23, lai22, ravouxcarreres, rcrk24
+from . import cov_utils
 from .covariance import CovMatrix
diff --git a/flip/covariance/cov_utils.py b/flip/covariance/cov_utils.py
index 27bceb3..5e70f97 100644
--- a/flip/covariance/cov_utils.py
+++ b/flip/covariance/cov_utils.py
@@ -231,6 +231,17 @@ def return_full_cov(cov):
     return full_cov
 
 
+def return_flat_cov(cov):
+    """Flatten a full auto-covariance matrix to flip's compact storage:
+    the shared variance first, then the upper triangle (k=1)."""
+    variance_val = cov[0, 0]
+    flat_cov = cov[np.triu_indices_from(cov, k=1)]
+    flat_cov = np.insert(flat_cov, 0, variance_val)
+    return flat_cov
+
+
+def return_flat_cross_cov(cov):
+    """Flatten a cross-covariance matrix; no variance element is stored."""
+    return cov.flatten()
+
+
 def return_full_cov_cross(cov, number_objects_g, number_objects_v):
     """
     The return_full_cov_cross function takes in a covariance matrix and the number of objects in each band.
@@ -245,7 +256,7 @@ def return_full_cov_cross(cov, number_objects_g, number_objects_v):
         The full covariance matrix
     """
-    full_cov = cov[1:].reshape((number_objects_g, number_objects_v))
+    full_cov = cov.reshape((number_objects_g, number_objects_v))
    return full_cov
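
The flat storage above keeps the shared variance as element 0 followed by the upper triangle, so a symmetric n x n block compresses to n(n-1)/2 + 1 values, while a cross block is simply flattened to n_g * n_v values with no variance element. A minimal numpy sketch of the auto-covariance round trip, assuming flip's shared-diagonal-variance convention:

    import numpy as np

    n = 4
    rng = np.random.default_rng(0)
    a = rng.normal(size=(n, n))
    cov = a @ a.T
    cov[np.diag_indices(n)] = cov[0, 0]  # shared variance on the diagonal

    # return_flat_cov: variance first, then the upper triangle (k=1)
    flat = np.insert(cov[np.triu_indices(n, k=1)], 0, cov[0, 0])
    assert flat.size == n * (n - 1) // 2 + 1

    # return_full_cov, reconstructed by hand for illustration
    full = np.zeros((n, n))
    full[np.triu_indices(n, k=1)] = flat[1:]
    full += full.T
    full[np.diag_indices(n)] = flat[0]
    np.testing.assert_allclose(full, cov)
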
diff --git a/flip/covariance/covariance.py b/flip/covariance/covariance.py
index b82826b..2a1718e 100644
--- a/flip/covariance/covariance.py
+++ b/flip/covariance/covariance.py
@@ -4,12 +4,179 @@
 
 import numpy as np
 
-from flip.covariance import cov_utils
 from flip.utils import create_log
 
+try:
+    import jax.numpy as jnp
+    from jax import jit
+
+    jax_installed = True
+except ImportError:
+    import numpy as jnp
+
+    jax_installed = False
+
+from flip.covariance import cov_utils
+
 log = create_log()
 
 
+def compute_covariance_sum_density(
+    coefficients_dict,
+    covariance_dict,
+    coefficients_dict_diagonal,
+    vector_err,
+    number_densities,
+    number_velocities,
+):
+    # number_densities / number_velocities are unused here; the signature is
+    # shared by all compute_covariance_sum_* functions so that the model type
+    # can be dispatched uniformly (and jitted with the same calling code).
+    covariance_sum = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["gg"][i] * cov
+                for i, cov in enumerate(covariance_dict["gg"])
+            ]
+        ),
+        axis=0,
+    )
+    covariance_sum += jnp.diag(coefficients_dict_diagonal["gg"] + vector_err**2)
+
+    return covariance_sum
+
+
+def compute_covariance_sum_velocity(
+    coefficients_dict,
+    covariance_dict,
+    coefficients_dict_diagonal,
+    vector_err,
+    number_densities,
+    number_velocities,
+):
+    covariance_sum = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["vv"][i] * cov
+                for i, cov in enumerate(covariance_dict["vv"])
+            ]
+        ),
+        axis=0,
+    )
+
+    covariance_sum += jnp.diag(coefficients_dict_diagonal["vv"] + vector_err**2)
+
+    return covariance_sum
+
+
+def compute_covariance_sum_density_velocity(
+    coefficients_dict,
+    covariance_dict,
+    coefficients_dict_diagonal,
+    vector_err,
+    number_densities,
+    number_velocities,
+):
+
+    density_err = vector_err[:number_densities]
+    velocity_err = vector_err[number_densities : number_densities + number_velocities]
+
+    covariance_sum_gv = jnp.zeros((number_densities, number_velocities))
+    covariance_sum_gg = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["gg"][i] * cov
+                for i, cov in enumerate(covariance_dict["gg"])
+            ]
+        ),
+        axis=0,
+    )
+    covariance_sum_gg += jnp.diag(coefficients_dict_diagonal["gg"] + density_err**2)
+
+    covariance_sum_vv = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["vv"][i] * cov
+                for i, cov in enumerate(covariance_dict["vv"])
+            ]
+        ),
+        axis=0,
+    )
+
+    covariance_sum_vv += jnp.diag(coefficients_dict_diagonal["vv"] + velocity_err**2)
+
+    covariance_sum_vg = -covariance_sum_gv.T
+
+    covariance_sum = jnp.block(
+        [
+            [covariance_sum_gg, covariance_sum_gv],
+            [covariance_sum_vg, covariance_sum_vv],
+        ]
+    )
+    return covariance_sum
+
+
+def compute_covariance_sum_full(
+    coefficients_dict,
+    covariance_dict,
+    coefficients_dict_diagonal,
+    vector_err,
+    number_densities,
+    number_velocities,
+):
+
+    density_err = vector_err[:number_densities]
+    velocity_err = vector_err[number_densities : number_densities + number_velocities]
+
+    covariance_sum_gv = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["gv"][i] * cov
+                for i, cov in enumerate(covariance_dict["gv"])
+            ]
+        ),
+        axis=0,
+    )
+    covariance_sum_gg = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["gg"][i] * cov
+                for i, cov in enumerate(covariance_dict["gg"])
+            ]
+        ),
+        axis=0,
+    )
+    covariance_sum_gg += jnp.diag(coefficients_dict_diagonal["gg"] + density_err**2)
+
+    covariance_sum_vv = jnp.sum(
+        jnp.array(
+            [
+                coefficients_dict["vv"][i] * cov
+                for i, cov in enumerate(covariance_dict["vv"])
+            ]
+        ),
+        axis=0,
+    )
+
+    covariance_sum_vv += jnp.diag(coefficients_dict_diagonal["vv"] + velocity_err**2)
+
+    covariance_sum_vg = -covariance_sum_gv.T
+
+    covariance_sum = jnp.block(
+        [
+            [covariance_sum_gg, covariance_sum_gv],
+            [covariance_sum_vg, covariance_sum_vv],
+        ]
+    )
+    return covariance_sum
+
+
+if jax_installed:
+
+    compute_covariance_sum_density_jit = jit(compute_covariance_sum_density)
+    compute_covariance_sum_velocity_jit = jit(compute_covariance_sum_velocity)
+    compute_covariance_sum_density_velocity_jit = jit(
+        compute_covariance_sum_density_velocity
+    )
+    compute_covariance_sum_full_jit = jit(compute_covariance_sum_full)
+
+
 class CovMatrix:
     def __init__(
         self,
@@ -283,6 +450,7 @@ def compute_covariance_sum(
         self,
         parameter_values_dict,
         vector_err,
+        use_jit=False,
     ):
         """
         The compute_covariance_sum function computes the sum of all covariance matrices
@@ -312,79 +480,18 @@ def compute_covariance_sum(
             parameter_values_dict,
         )
 
-        if self.model_type == "density":
-            covariance_sum = np.sum(
-                [
-                    coefficients_dict["gg"][i] * cov
-                    for i, cov in enumerate(self.covariance_dict["gg"])
-                ],
-                axis=0,
-            )
-            covariance_sum += np.diag(coefficients_dict_diagonal["gg"] + vector_err**2)
-
-        elif self.model_type == "velocity":
-            covariance_sum = np.sum(
-                [
-                    coefficients_dict["vv"][i] * cov
-                    for i, cov in enumerate(self.covariance_dict["vv"])
-                ],
-                axis=0,
-            )
-
-            covariance_sum += np.diag(coefficients_dict_diagonal["vv"] + vector_err**2)
-
-        elif self.model_type in ["density_velocity", "full"]:
-            number_densities = self.number_densities
-            number_velocities = self.number_velocities
-            density_err = vector_err[:number_densities]
-            velocity_err = vector_err[
-                number_densities : number_densities + number_velocities
-            ]
-
-            if self.model_type == "density_velocity":
-                covariance_sum_gv = np.zeros((number_densities, number_velocities))
-            elif self.model_type == "full":
-                covariance_sum_gv = np.sum(
-                    [
-                        coefficients_dict["gv"][i] * cov
-                        for i, cov in enumerate(self.covariance_dict["gv"])
-                    ],
-                    axis=0,
-                )
-            covariance_sum_gg = np.sum(
-                [
-                    coefficients_dict["gg"][i] * cov
-                    for i, cov in enumerate(self.covariance_dict["gg"])
-                ],
-                axis=0,
-            )
-            covariance_sum_gg += np.diag(
-                coefficients_dict_diagonal["gg"] + density_err**2
-            )
-
-            covariance_sum_vv = np.sum(
-                [
-                    coefficients_dict["vv"][i] * cov
-                    for i, cov in enumerate(self.covariance_dict["vv"])
-                ],
-                axis=0,
-            )
-
-            covariance_sum_vv += np.diag(
-                coefficients_dict_diagonal["vv"] + velocity_err**2
-            )
-
-            covariance_sum_vg = -covariance_sum_gv.T
-
-            covariance_sum = np.block(
-                [
-                    [covariance_sum_gg, covariance_sum_gv],
-                    [covariance_sum_vg, covariance_sum_vv],
-                ]
-            )
-        else:
-            log.add(f"Wrong model type in the loaded covariance.")
-
+        if self.model_type not in ("density", "velocity", "density_velocity", "full"):
+            raise ValueError(
+                f"Wrong model type {self.model_type} in the loaded covariance."
+            )
+        covariance_sum_func = eval(
+            f"compute_covariance_sum_{self.model_type}"
+            + f"{'_jit' if jax_installed and use_jit else ''}"
+        )
+        covariance_sum = covariance_sum_func(
+            coefficients_dict,
+            self.covariance_dict,
+            coefficients_dict_diagonal,
+            vector_err,
+            self.number_densities,
+            self.number_velocities,
+        )
         return covariance_sum
 
     def compute_covariance_sum_eigenvalues(
@@ -409,43 +516,80 @@ def compute_full_matrix(self):
 
             A dictionary with the full covariance matrices for each redshift bin
         """
-        if self.full_matrix is False:
-            for key in ["gg", "vv", "gv"]:
-                if key in self.covariance_dict.keys():
-                    if key == "gg":
-                        new_shape = (
-                            self.covariance_dict[key].shape[0],
-                            self.number_densities,
-                            self.number_densities,
-                        )
-                    elif key == "gv":
-                        new_shape = (
-                            self.covariance_dict[key].shape[0],
-                            self.number_densities,
-                            self.number_velocities,
-                        )
-                    elif key == "vv":
-                        new_shape = (
-                            self.covariance_dict[key].shape[0],
-                            self.number_velocities,
-                            self.number_velocities,
-                        )
-                    new_cov = np.zeros(new_shape)
-                    for i, _ in enumerate(self.covariance_dict[key]):
-                        if key == "gv":
-                            new_cov[i] = cov_utils.return_full_cov_cross(
-                                self.covariance_dict[key][i],
-                                self.number_densities,
-                                self.number_velocities,
-                            )
-                        else:
-                            new_cov[i] = cov_utils.return_full_cov(
-                                self.covariance_dict[key][i]
-                            )
-                    self.covariance_dict[key] = new_cov
+        if self.full_matrix:
+            log.add("Full matrix already computed")
+            return
+
+        for key in self.covariance_dict.keys():
+            if key == "gg":
+                new_shape = (
+                    self.covariance_dict[key].shape[0],
+                    self.number_densities,
+                    self.number_densities,
+                )
+            elif key == "gv":
+                new_shape = (
+                    self.covariance_dict[key].shape[0],
+                    self.number_densities,
+                    self.number_velocities,
+                )
+            elif key == "vv":
+                new_shape = (
+                    self.covariance_dict[key].shape[0],
+                    self.number_velocities,
+                    self.number_velocities,
+                )
+            else:
+                log.warning(f"Key {key} is not one of 'gg', 'gv', 'vv'; it was ignored")
+                continue
+
+            new_cov = np.zeros(new_shape)
+            for i, _ in enumerate(self.covariance_dict[key]):
+                if key[0] != key[1]:
+                    new_cov[i] = cov_utils.return_full_cov_cross(
+                        self.covariance_dict[key][i],
+                        self.number_densities,
+                        self.number_velocities,
+                    )
+                else:
+                    new_cov[i] = cov_utils.return_full_cov(self.covariance_dict[key][i])
+            self.covariance_dict[key] = new_cov
 
         self.full_matrix = True
 
+    def compute_flat_matrix(self):
+        """Convert the full covariance matrices back to flip's flat storage:
+        shared variance + upper triangle for the auto terms ('gg', 'vv'),
+        plain flattening for the cross term ('gv')."""
+        for key in self.covariance_dict.keys():
+            if key == "gg":
+                new_shape = (
+                    self.covariance_dict[key].shape[0],
+                    int(self.number_densities * (self.number_densities - 1) / 2) + 1,
+                )
+            elif key == "gv":
+                new_shape = (
+                    self.covariance_dict[key].shape[0],
+                    self.number_densities * self.number_velocities,
+                )
+            elif key == "vv":
+                new_shape = (
+                    self.covariance_dict[key].shape[0],
+                    int(self.number_velocities * (self.number_velocities - 1) / 2) + 1,
+                )
+            else:
+                log.warning(f"Key {key} is not one of 'gg', 'gv', 'vv'; it was ignored")
was ignored") + continue + + new_cov = np.zeros(new_shape) + for i, _ in enumerate(self.covariance_dict[key]): + if key == "gv": + new_cov[i] = cov_utils.return_flat_cross_cov( + self.covariance_dict[key][i], + ) + else: + new_cov[i] = cov_utils.return_flat_cov(self.covariance_dict[key][i]) + self.covariance_dict[key] = new_cov + + self.full_matrix = False + def write( self, filename, @@ -493,3 +637,97 @@ def write( elif file_format == "npz": np.savez(f"{filename}.npz", **class_attrs_dictionary) + + def mask(self, mask_vel=None, mask_dens=None): + + Ng = self.number_densities + Nv = self.number_velocities + + if mask_vel is None and mask_dens is None: + raise ValueError("No mask set") + + masked_cov_dic = {} + if mask_vel is not None: + if len(mask_vel) != self.number_velocities: + raise ValueError("Velocities mask size does not match vel cov size") + + Nv = np.sum(mask_vel) + cov_vv_mask = np.outer(mask_vel, mask_vel) + + if self.full_matrix: + masked_cov_dic["vv"] = np.array( + [ + cov[cov_vv_mask].reshape((Nv, Nv)) + for cov in self.covariance_dict["vv"] + ] + ) + else: + cov_vv_mask = cov_vv_mask[np.triu_indices(self.number_velocities, k=1)] + cov_vv_mask = np.insert(cov_vv_mask, 0, True) + + masked_cov_dic["vv"] = np.array( + [cov[cov_vv_mask] for cov in self.covariance_dict["vv"]] + ) + + if mask_dens is not None: + if len(mask_dens) != self.number_densities: + raise ValueError("Densities mask size does not match density cov size") + + Ng = np.sum(mask_dens) + cov_gg_mask = np.outer(mask_dens, mask_dens) + + if self.full_matrix: + masked_cov_dic["gg"] = np.array( + [ + cov[cov_gg_mask].reshape((Ng, Ng)) + for cov in self.covariance_dict["gg"] + ] + ) + else: + cov_gg_mask = cov_gg_mask[np.triu_indices(self.number_densities, k=1)] + cov_gg_mask = np.insert(cov_gg_mask, 0, True) + masked_cov_dic["gg"] = np.array( + [cov[cov_gg_mask] for cov in self.covariance_dict["gg"]] + ) + + if self.number_densities is not None and self.number_velocities is not None: + if mask_vel is None: + cov_gv_mask = np.outer( + mask_dens, np.ones(self.number_velocities, dtype="bool") + ) + elif mask_dens is None: + cov_gv_mask = np.outer( + np.ones(self.number_densities, dtype="bool"), mask_vel + ) + else: + cov_gv_mask = np.outer(mask_dens, mask_vel) + + if self.full_matrix: + masked_cov_dic["gv"] = np.array( + [ + cov[cov_gv_mask].reshape((Ng, Nv)) + for cov in self.covariance_dict["gv"] + ] + ) + else: + cov_gv_mask = cov_gv_mask.flatten() + masked_cov_dic["gv"] = np.array( + [cov[cov_gv_mask] for cov in self.covariance_dict["gv"]] + ) + + for k in self.covariance_dict: + if k not in masked_cov_dic: + masked_cov_dic[k] = self.covariance_dict[k] + + return CovMatrix( + model_name=self.model_name, + model_type=self.model_type, + los_definition=self.los_definition, + covariance_dict=masked_cov_dic, + full_matrix=self.full_matrix, + number_densities=np.sum(mask_dens), + number_velocities=np.sum(mask_vel), + redshift_dict=self.redshift_dict, + power_spectrum_amplitude_function=self.power_spectrum_amplitude_function, + variant=self.variant, + ) diff --git a/flip/covariance/generator.py b/flip/covariance/generator.py index da21546..51f2c96 100644 --- a/flip/covariance/generator.py +++ b/flip/covariance/generator.py @@ -4,7 +4,6 @@ import cosmoprimo import numpy as np from scipy import integrate -from scipy.interpolate import interp1d from scipy.special import spherical_jn from flip.covariance import cov_utils @@ -52,7 +51,7 @@ def correlation_integration(l, r, k, integrand): return (-1) ** (l % 2) * 
diff --git a/flip/covariance/generator.py b/flip/covariance/generator.py
index da21546..51f2c96 100644
--- a/flip/covariance/generator.py
+++ b/flip/covariance/generator.py
@@ -4,7 +4,6 @@
 import cosmoprimo
 import numpy as np
 from scipy import integrate
-from scipy.interpolate import interp1d
 from scipy.special import spherical_jn
 
 from flip.covariance import cov_utils
@@ -52,7 +51,7 @@ def correlation_integration(l, r, k, integrand):
     return (-1) ** (l % 2) * integrate.simpson(integrand, x=k)
 
 
-def correlation_hankel(l, r, k, integrand, hankel_overhead_coefficient=2):
+def correlation_hankel(l, r, k, integrand, hankel_overhead_coefficient=2, kmin=None):
     """
     The correlation_hankel function is a wrapper for the cosmoprimo.fftlog.PowerToCorrelation function,
     which computes the correlation function from power spectrum using FFTLog (Hamilton 2000).
@@ -77,9 +76,20 @@ def correlation_hankel(l, r, k, integrand, hankel_overhead_coefficient=2):
     Hankel.set_fft_engine("numpy")
     r_hankel, xi_hankel = Hankel(integrand)
     mask = r < np.min(r_hankel) * hankel_overhead_coefficient
+    if np.any(r > np.max(r_hankel)):
+        raise ValueError(
+            "The minimum power-spectrum k is too high for the requested separations; "
+            "please provide a power spectrum extending to lower k. "
+            "The kmin parameter can then be used to bound the integration."
+        )
     output = np.empty_like(r)
     output[mask] = correlation_integration(l, r[mask], k, integrand)
-    output[~mask] = (-1) ** (l % 2) * interp1d(r_hankel, xi_hankel)(r[~mask])
+    output[~mask] = (-1) ** (l % 2) * np.interp(r[~mask], r_hankel, xi_hankel)
+
+    # Regularization: subtract the contribution of modes below kmin,
+    # integrated on a log-spaced grid with log-log interpolation.
+    if kmin is not None:
+        kreg = np.geomspace(np.min(k), kmin, int(len(k) / 10))
+        integrand_reg = np.exp(np.interp(np.log(kreg), np.log(k), np.log(integrand)))
+        output -= correlation_integration(l, r, kreg, integrand_reg)
+
     return output
 
 
@@ -92,6 +102,7 @@ def coefficient_hankel(
     power_spectrum,
     coord,
     additional_parameters_values=None,
+    kmin=None,
 ):
     """
     The coefficient_hankel function computes the covariance between two terms of a given model.
@@ -128,7 +139,11 @@ def coefficient_hankel(
             coord[1], coord[2]
         )
         hankel_ab_i_l_j = correlation_hankel(
-            l, coord[0], wavenumber, M_ab_i_l_j(wavenumber) * power_spectrum
+            l,
+            coord[0],
+            wavenumber,
+            M_ab_i_l_j(wavenumber) * power_spectrum,
+            kmin=kmin,
         )
         cov_ab_i = cov_ab_i + N_ab_i_l_j * hankel_ab_i_l_j
     return cov_ab_i
@@ -262,6 +277,7 @@ def compute_coeficient(
     additional_parameters_values=None,
     number_worker=8,
     hankel=True,
+    kmin=None,
 ):
     """
     The compute_coeficient function computes the covariance matrix for a given model.
@@ -305,8 +321,8 @@ def compute_coeficient(
                 power_spectrum_list[index_power_spectrum][0],
                 power_spectrum_list[index_power_spectrum][1],
                 additional_parameters_values=additional_parameters_values,
+                kmin=kmin,
             )
-
     if number_worker == 1:
         for i, index in enumerate(term_index_list):
             loc = locals()
@@ -324,23 +340,25 @@ def compute_coeficient(
                 eval(f"map_async_{index}").get()
             )
 
-    for i, index in enumerate(term_index_list):
-        if multi_index_model:
-            index_power_spectrum = int(index[0])
-        else:
-            index_power_spectrum = i
-        variance_t = coefficient(
-            model_name,
-            covariance_type,
-            index,
-            lmax_list[i],
-            power_spectrum_list[index_power_spectrum][0],
-            power_spectrum_list[index_power_spectrum][1],
-            np.zeros((3, 1)),
-            additional_parameters_values=additional_parameters_values,
-        )[0]
-
-        locals()[f"cov_{index}"] = np.insert(eval(f"cov_{index}"), 0, variance_t)
+    # In the case of autocorrelation, add the theoretical variance.
+    if covariance_type[0] == covariance_type[1]:
+        for i, index in enumerate(term_index_list):
+            if multi_index_model:
+                index_power_spectrum = int(index[0])
+            else:
+                index_power_spectrum = i
+            variance_t = coefficient(
+                model_name,
+                covariance_type,
+                index,
+                lmax_list[i],
+                power_spectrum_list[index_power_spectrum][0],
+                power_spectrum_list[index_power_spectrum][1],
+                np.zeros((3, 1)),
+                additional_parameters_values=additional_parameters_values,
+            )[0]
+
+            locals()[f"cov_{index}"] = np.insert(eval(f"cov_{index}"), 0, variance_t)
 
     loc = locals()
     return np.array(
@@ -348,9 +366,6 @@ def compute_coeficient(
     )
 
 
-# CR - Need to remove variance at zero separation for vg
-
-
 def compute_cov(
     model_name,
     covariance_type,
@@ -362,6 +377,7 @@ def compute_cov(
     number_worker=8,
     hankel=True,
     los_definition="bisector",
+    kmin=None,
 ):
     """
     The compute_cov function computes the covariance matrix for a given model.
@@ -403,6 +419,7 @@ def compute_cov(
         additional_parameters_values=additional_parameters_values,
         number_worker=number_worker,
        hankel=hankel,
+        kmin=kmin,
    )
 
     return covariance
@@ -462,6 +479,7 @@ def generate_covariance(
     number_worker=8,
     hankel=True,
     los_definition="bisector",
+    kmin=None,
 ):
     """
     The generate_flip function computes the covariance matrix for a given model.
@@ -501,6 +519,7 @@ def generate_covariance(
             number_worker=number_worker,
             hankel=hankel,
             los_definition=los_definition,
+            kmin=kmin,
         )
         number_densities = len(coordinates_density[0])
     else:
@@ -518,6 +537,7 @@ def generate_covariance(
             number_worker=number_worker,
             hankel=hankel,
             los_definition=los_definition,
+            kmin=kmin,
         )
         number_velocities = len(coordinates_velocity[0])
     else:
@@ -535,6 +555,7 @@ def generate_covariance(
             number_worker=number_worker,
             hankel=hankel,
             los_definition=los_definition,
+            kmin=kmin,
         )
 
     redshift_dict = generate_redshift_dict(
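
The kmin option threaded through above subtracts the k < kmin contribution from each transform, interpolating the integrand in log-log space as in correlation_hankel. A toy, self-contained illustration of that subtraction (toy integrand, scipy only; the Bessel weighting of the real integrand is omitted):

    import numpy as np
    from scipy import integrate

    k = np.geomspace(1e-5, 1e2, 2048)
    pk = 1.0 / (1.0 + (k / 0.1) ** 2)  # toy stand-in for M(k) * P(k)
    kmin = 1e-3

    # Log-spaced grid over the modes to remove, log-log interpolation of the integrand.
    kreg = np.geomspace(k.min(), kmin, len(k) // 10)
    pk_reg = np.exp(np.interp(np.log(kreg), np.log(k), np.log(pk)))
    low_k_contribution = integrate.simpson(pk_reg, x=kreg)
    print(low_k_contribution)  # the amount the regularization removes
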
diff --git a/flip/fitter.py b/flip/fitter.py
index 1a81f6a..b8fde3d 100644
--- a/flip/fitter.py
+++ b/flip/fitter.py
@@ -1,17 +1,28 @@
 import abc
 import multiprocessing as mp
 import os
+from contextlib import nullcontext
 
 import emcee
 import iminuit
 import numpy as np
 
-import flip.likelihood as flik
-from flip.covariance.covariance import CovMatrix
+try:
+    from jax import grad as jax_grad
+
+    jax_installed = True
+except ImportError:
+    jax_installed = False
+
 from flip.utils import create_log
 
 log = create_log()
 
+import flip.likelihood as flik
+from flip.covariance.covariance import CovMatrix
+
 
 class BaseFitter(abc.ABC):
     def __init__(
@@ -205,9 +216,14 @@ def init_from_covariance(
             parameter_dict[parameters]["value"] for parameters in parameter_dict
         ]
 
+        if jax_installed and likelihood.likelihood_properties["use_gradient"]:
+            grad = jax_grad(likelihood)
+        else:
+            grad = None
         minuit_fitter.minuit = iminuit.Minuit(
             likelihood,
             parameter_values,
+            grad=grad,
             name=likelihood.parameter_names,
         )
 
@@ -460,19 +476,25 @@ def __init__(self, likelihood, p0=None, backend_file=None):
 
         self.backend = None
         if backend_file is not None:
+            backend_file_exists = os.path.exists(backend_file)
             self.backend = emcee.backends.HDFBackend(backend_file)
-            if os.path.exists(backend_file):
+            if backend_file_exists:
                 log.add(
-                    "File already exist"
-                    "Initial size: {0}".format(self.backend.iteration)
+                    "File already exists. "
+                    "Initial size: {0}".format(self.backend.iteration)
                 )
+                if self.backend.iteration == 0:
+                    log.add("Backend file is empty, please delete it and relaunch")
                 self._p0 = None
                 self.nwalkers = self.backend.shape[0]
             else:
-                log.add("Create new file to store chains")
+                log.add("No existing file, will create a new one")
+            self.backend_file_exists = backend_file_exists
+        else:
+            self.backend_file_exists = False
 
     def run_chains(self, nsteps, number_worker=1, progress=False):
-        with mp.Pool(number_worker) as pool:
+        with mp.Pool(number_worker) if number_worker != 1 else nullcontext() as pool:
             sampler = emcee.EnsembleSampler(
                 self.nwalkers,
                 self.ndim,
@@ -480,7 +502,11 @@ def run_chains(self, nsteps, number_worker=1, progress=False):
                 pool=pool,
                 backend=self.backend,
             )
-            sampler.run_mcmc(self.p0, nsteps, progress=progress)
+            sampler.run_mcmc(
+                self.p0,
+                nsteps,
+                progress=progress,
+            )
         return sampler
 
     def run_chains_untilconv(
@@ -491,8 +517,8 @@ def run_chains_untilconv(
         progress=False,
     ):
         """Run chains until reaching auto correlation convergence criteria."""
-        tau = np.inf
-        with mp.Pool(number_worker) as pool:
+        old_tau = np.inf
+        with mp.Pool(number_worker) if number_worker != 1 else nullcontext() as pool:
             sampler = emcee.EnsembleSampler(
                 self.nwalkers,
                 self.ndim,
@@ -500,16 +526,23 @@ def run_chains_untilconv(
                 pool=pool,
                 backend=self.backend,
             )
-            for sample in sampler.sample(
-                self.p0, iterations=maxstep, progress=progress
-            ):
-                if sampler.iteration % 500 == 0:
-                    # Compute tau
-                    tau = sampler.get_autocorr_time(tol=0)
-                    # Check convergence
-                    converged = np.all(tau * 100 < sampler.iteration)
-                    converged &= np.all(np.abs(old_tau - tau) / tau < tau_conv)
-                    if converged:
-                        break
-                    old_tau = tau
+            if not self.backend_file_exists:
+                for _ in sampler.sample(self.p0, iterations=maxstep, progress=progress):
+                    if sampler.iteration % 500 == 0:
+                        # Compute tau
+                        tau = sampler.get_autocorr_time(tol=0)
+                        # Check convergence
+                        converged = np.all(tau * 100 < sampler.iteration)
+                        converged &= np.all(np.abs(old_tau - tau) / tau < tau_conv)
+                        if converged:
+                            break
+                        old_tau = tau
+            else:
+                # If the file already exists run to max step.
+                sampler.run_mcmc(
+                    None,
+                    maxstep - self.backend.iteration,
+                    progress=progress,
+                )
         return sampler
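
The gradient wiring above relies on iminuit accepting an analytic gradient callable with the same signature as the cost function. A minimal self-contained sketch, with a toy quadratic cost standing in for the flip likelihood object:

    import iminuit
    import jax
    import numpy as np

    def cost(params):  # stand-in for the likelihood object
        return (params[0] - 1.0) ** 2 + (params[1] - 2.0) ** 2

    m = iminuit.Minuit(cost, np.array([0.0, 0.0]), grad=jax.grad(cost), name=["a", "b"])
    m.migrad()
    print(m.values["a"], m.values["b"])  # both converge to 1 and 2
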
diff --git a/flip/likelihood.py b/flip/likelihood.py
index ee4de93..e3960f1 100644
--- a/flip/likelihood.py
+++ b/flip/likelihood.py
@@ -3,25 +3,65 @@
 import numpy as np
 import scipy as sc
 
+try:
+    import jax.numpy as jnp
+    import jax.scipy as jsc
+    from jax import jit
+
+    jax_installed = True
+except ImportError:
+    import numpy as jnp
+    import scipy as jsc
+
+    jax_installed = False
+
 from flip import vectors
 from flip.utils import create_log
 
+# try:
+#     import jax, jaxlib
+#     import jax.numpy as jnp
+# except ImportError:
+#     jax = None
+#     import numpy as jnp
+
+
+# def use_jax(array):
+#     """Whether to use jax.numpy depending on whether array is jax's object"""
+#     return jax and isinstance(array, (jaxlib.xla_extension.DeviceArrayBase, jax.core.Tracer))
+
+
+# def np_jax(array):
+#     """Return numpy or jax.numpy depending on whether array is jax's object"""
+#     if use_jax(array):
+#         return jnp
+#     return np
+
+# CR - cool implementation - to test
+
 log = create_log()
+
+
 _available_priors = ["gaussian", "positive", "uniform"]
 
 
 def log_likelihood_gaussian_inverse(vector, covariance_sum):
-    _, logdet = np.linalg.slogdet(covariance_sum)
-    inverse_covariance_sum = np.linalg.inv(covariance_sum)
-    chi2 = np.dot(vector, np.dot(inverse_covariance_sum, vector))
-    return -0.5 * (vector.size * np.log(2 * np.pi) + logdet + chi2)
+    _, logdet = jnp.linalg.slogdet(covariance_sum)
+    inverse_covariance_sum = jnp.linalg.inv(covariance_sum)
+    chi2 = jnp.dot(vector, jnp.dot(inverse_covariance_sum, vector))
+    return -0.5 * (vector.size * jnp.log(2 * np.pi) + logdet + chi2)
 
 
 def log_likelihood_gaussian_cholesky(vector, covariance_sum):
-    cholesky = sc.linalg.cho_factor(covariance_sum)
-    logdet = 2 * np.sum(np.log(np.diag(cholesky[0])))
-    chi2 = np.dot(vector, sc.linalg.cho_solve(cholesky, vector))
-    return -0.5 * (vector.size * np.log(2 * np.pi) + logdet + chi2)
+    cholesky = jsc.linalg.cho_factor(covariance_sum)
+    logdet = 2 * jnp.sum(jnp.log(jnp.diag(cholesky[0])))
+    chi2 = jnp.dot(vector, jsc.linalg.cho_solve(cholesky, vector))
+    return -0.5 * (vector.size * jnp.log(2 * np.pi) + logdet + chi2)
+
+
+if jax_installed:
+    log_likelihood_gaussian_inverse_jit = jit(log_likelihood_gaussian_inverse)
+    log_likelihood_gaussian_cholesky_jit = jit(log_likelihood_gaussian_cholesky)
 
 
 def no_prior(x):
@@ -32,6 +72,41 @@ def prior_sum(priors, x):
     return sum(prior(x) for prior in priors)
 
 
+def interpolate_covariance_sum_1d(
+    interpolation_value_range,
+    interpolation_value,
+    covariance,
+    parameter_values_dict,
+    vector_error,
+):
+    if np.isnan(interpolation_value):
+        return np.full_like(
+            covariance[0].compute_covariance_sum(parameter_values_dict, vector_error),
+            np.nan,
+        )
+    upper_index_interpolation = jnp.searchsorted(
+        interpolation_value_range, interpolation_value
+    )
+    covariance_sum_upper = covariance[upper_index_interpolation].compute_covariance_sum(
+        parameter_values_dict, vector_error
+    )
+
+    covariance_sum_lower = covariance[
+        upper_index_interpolation - 1
+    ].compute_covariance_sum(parameter_values_dict, vector_error)
+
+    fraction_interpolation = (
+        interpolation_value_range[upper_index_interpolation] - interpolation_value
+    ) / (
+        interpolation_value_range[upper_index_interpolation]
+        - interpolation_value_range[upper_index_interpolation - 1]
+    )
+    covariance_sum = (
+        1 - fraction_interpolation
+    ) * covariance_sum_upper + fraction_interpolation * covariance_sum_lower
+    return covariance_sum
+
+
 class BaseLikelihood(object):
 
     _default_likelihood_properties = {
@@ -39,6 +114,8 @@ class BaseLikelihood(object):
         "velocity_type": "direct",
         "velocity_estimator": "full",
         "negative_log_likelihood": True,
+        "use_jit": False,
+        "use_gradient": False,
     }
 
     def __init__(
@@ -117,16 +194,15 @@ def load_data_vector(
             density, density_error = vectors.load_density_vectors(self.data)
 
         if model_type == "density":
-            return density, density_error
+            vector, vector_error = density, density_error
         elif model_type == "velocity":
-            return velocity, velocity_error
+            vector, vector_error = velocity, velocity_error
         elif model_type in ["density_velocity", "full"]:
-            return (
-                np.concatenate([density, velocity], axis=0),
-                np.concatenate([density_error, velocity_error], axis=0),
-            )
+            vector = np.concatenate([density, velocity], axis=0)
+            vector_error = np.concatenate([density_error, velocity_error], axis=0)
         else:
-            log.add(f"Wrong model type in the loaded covariance.")
+            raise ValueError("Wrong model type in the loaded covariance.")
+        return vector, vector_error
 
     def initialize_prior(
         self,
@@ -144,17 +220,17 @@ def initialize_prior(
             )
         elif prior_properties["type"].lower() == "gaussian":
             prior = GaussianPrior(
-                parameter_name=parameter_name.lower(),
+                parameter_name=parameter_name,
                 prior_mean=prior_properties["mean"],
                 prior_standard_deviation=prior_properties["standard_deviation"],
             )
         elif prior_properties["type"].lower() == "positive":
             prior = PositivePrior(
-                parameter_name=parameter_name.lower(),
+                parameter_name=parameter_name,
             )
         elif prior_properties["type"].lower() == "uniform":
             prior = UniformPrior(
-                parameter_name=parameter_name.lower(),
+                parameter_name=parameter_name,
                 range=prior_properties["range"],
             )
         priors.append(prior)
@@ -192,17 +268,23 @@ def __call__(self, parameter_values):
             parameter_values_dict,
         )
         covariance_sum = self.covariance.compute_covariance_sum(
-            parameter_values_dict, vector_error
+            parameter_values_dict,
+            vector_error,
+            use_jit=self.likelihood_properties["use_jit"],
         )
         likelihood_function = eval(
             f"log_likelihood_gaussian_{self.likelihood_properties['inversion_method']}"
+            + f"{'_jit' if jax_installed and self.likelihood_properties['use_jit'] else ''}"
         )
         prior_value = self.prior(parameter_values_dict)
 
         if self.likelihood_properties["negative_log_likelihood"]:
-            return -likelihood_function(vector, covariance_sum) - prior_value
-
-        return likelihood_function(vector, covariance_sum) + prior_value
+            likelihood_value = (
+                -likelihood_function(vector, covariance_sum) - prior_value
+            )
+        else:
+            likelihood_value = likelihood_function(vector, covariance_sum) + prior_value
+        return likelihood_value
 
 
 class MultivariateGaussianLikelihoodInterpolate1D(BaseLikelihood):
@@ -296,30 +378,27 @@ def __call__(
             parameter_values_dict,
         )
 
-        covariance_sum_list = [
-            self.covariance[i].compute_covariance_sum(
-                parameter_values_dict, vector_error
-            )
-            for i in range(len(self.covariance))
-        ]
-
-        covariance_sum_interpolated = sc.interpolate.interp1d(
+        covariance_sum = interpolate_covariance_sum_1d(
             self.interpolation_value_range,
-            covariance_sum_list,
-            copy=False,
-            axis=0,
+            interpolation_value,
+            self.covariance,
+            parameter_values_dict,
+            vector_error,
         )
 
-        covariance_sum = covariance_sum_interpolated(interpolation_value)
-
         likelihood_function = eval(
             f"log_likelihood_gaussian_{self.likelihood_properties['inversion_method']}"
+            + f"{'_jit' if jax_installed and self.likelihood_properties['use_jit'] else ''}"
         )
         prior_value = self.prior(parameter_values_dict)
 
         if self.likelihood_properties["negative_log_likelihood"]:
-            return -likelihood_function(vector, covariance_sum) - prior_value
+            likelihood_value = (
+                -likelihood_function(vector, covariance_sum) - prior_value
+            )
+        else:
+            likelihood_value = likelihood_function(vector, covariance_sum) + prior_value
 
-        return likelihood_function(vector, covariance_sum) + prior_value
+        return likelihood_value
 
 
 class MultivariateGaussianLikelihoodInterpolate2D(BaseLikelihood):
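
The helper above replaces scipy's interp1d over a list of matrices with an explicit two-point linear interpolation between the covariance sums bracketing the value. On scalars it reduces to:

    import numpy as np

    grid = np.array([0.1, 0.2, 0.3])   # interpolation_value_range
    value = 0.25
    upper = np.searchsorted(grid, value)                            # -> 2
    frac = (grid[upper] - value) / (grid[upper] - grid[upper - 1])  # -> 0.5
    cov_upper, cov_lower = 4.0, 2.0    # stand-ins for the two covariance sums
    print((1 - frac) * cov_upper + frac * cov_lower)                # -> 3.0
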
diff --git a/flip/plot_utils.py b/flip/plot_utils.py
index 0478dcc..515b79f 100644
--- a/flip/plot_utils.py
+++ b/flip/plot_utils.py
@@ -185,27 +185,28 @@ def plot_all_fits(
     fiducials=None,
     compute_fs8_from_beta=False,
     subset_plot=None,
+    remove_lower=None,
+    remove_higher=None,
+    plot=True,
     **kwargs,
 ):
+
+    fit_to_plot, fit_name_to_plot = select_valid_fits(
+        fit_output,
+        subset_plot=subset_plot,
+        remove_lower=remove_lower,
+        remove_higher=remove_higher,
+    )
+
     figsize = utils.return_key(kwargs, "figsize", (10, 10))
-    all_fit = glob.glob(os.path.join(fit_output, "*"))
 
-    fig, ax = plt.subplots(len(parameters), 1, figsize=figsize, sharex=True)
-    fit_names, param_dict, error_dict = [], {}, {}
+    if plot:
+        fig, ax = plt.subplots(len(parameters), 1, figsize=figsize, sharex=True)
 
+    param_dict, error_dict = {}, {}
     for j, param_name in enumerate(parameters):
         param_dict[param_name] = []
         error_dict[param_name] = []
 
-    for i, f in enumerate(all_fit):
-        if subset_plot is not None:
-            if subset_plot not in f:
-                continue
-        fit = pickle.load(open(f, "rb"))
-        if fit[3] is False:
-            continue
-        elif fit[4] is False:
-            continue
-
-        fit_names.append(f)
+    for i, fit in enumerate(fit_to_plot):
         for j, param_name in enumerate(parameters):
             if (param_name == "fs8") & (compute_fs8_from_beta):
                 param = fit[0]["beta_f"] * fit[0]["bs8"]
@@ -222,24 +223,25 @@ def plot_all_fits(
                 error = fit[2][param_name]
 
             param_dict[param_name].append(param)
             error_dict[param_name].append(error)
-            ax[j].errorbar(
-                i,
-                param,
-                error,
-                marker=".",
-                ls="None",
-                color="C1",
-            )
-
-            ax[j].set_ylabel(param_name, fontsize=18)
+            if plot:
+                ax[j].errorbar(
+                    i,
+                    param,
+                    error,
+                    marker=".",
+                    ls="None",
+                    color="C1",
+                )
 
-            if fiducials is not None:
-                if fiducials[j] is not None:
-                    ax[j].axhline(fiducials[j], ls=":", color="k")
+                ax[j].set_ylabel(param_name, fontsize=18)
 
-    ax[0].margins(x=0.005)
-    fig.tight_layout()
-    return fit_names, param_dict, error_dict
+                if fiducials is not None:
+                    if fiducials[j] is not None:
+                        ax[j].axhline(fiducials[j], ls=":", color="k")
+    if plot:
+        ax[0].margins(x=0.005)
+        fig.tight_layout()
+    return fit_name_to_plot, param_dict, error_dict
 
 
 def plot_all_mean_fits(
@@ -248,26 +250,23 @@ def plot_all_mean_fits(
     fiducials=None,
     weighted_mean=True,
     compute_fs8_from_beta=False,
+    plot_std_error=False,
+    plot_error_bar_of_mean=True,
     subset_plot=None,
+    remove_lower=None,
+    remove_higher=None,
+    plot=True,
     **kwargs,
 ):
-    figsize = utils.return_key(kwargs, "figsize", (10, 10))
-    all_fit = glob.glob(os.path.join(fit_output, "*"))
+    fit_to_plot, fit_name_to_plot = select_valid_fits(
+        fit_output,
+        subset_plot=subset_plot,
+        remove_lower=remove_lower,
+        remove_higher=remove_higher,
+    )
 
-    fit_to_plot = []
-    fit_name_to_plot = []
-    for f in all_fit:
-        if subset_plot is not None:
-            if subset_plot not in f:
-                continue
-        fit = pickle.load(open(f, "rb"))
-        if fit[3] is False:
-            continue
-        elif fit[4] is False:
-            continue
-        fit_to_plot.append(fit)
-        fit_name_to_plot.append(f)
+    figsize = utils.return_key(kwargs, "figsize", (10, 10))
 
     fit_prop = []
     for i in range(len(fit_name_to_plot)):
@@ -277,16 +276,17 @@ def plot_all_mean_fits(
 
     fit_prop = np.array(fit_prop)
     unique_fit_prop = np.sort(np.unique(fit_prop))
-
-    fig, ax = plt.subplots(len(parameters), 1, figsize=figsize, sharex=True)
-    fig2, ax2 = plt.subplots(len(parameters), 1, figsize=figsize, sharex=True)
+    if plot:
+        fig, ax = plt.subplots(len(parameters), 1, figsize=figsize, sharex=True)
+        fig2, ax2 = plt.subplots(len(parameters), 1, figsize=figsize, sharex=True)
 
     text = []
-    mean_param_dict, mean_error_dict, error_mean_dict = {}, {}, {}
+    mean_param_dict, mean_error_dict, error_mean_dict, std_dict = {}, {}, {}, {}
     for j, param_name in enumerate(parameters):
         mean_param_dict[param_name] = []
         mean_error_dict[param_name] = []
         error_mean_dict[param_name] = []
+        std_dict[param_name] = []
 
     for i, fit_p in enumerate(unique_fit_prop):
 
@@ -320,33 +320,88 @@ def plot_all_mean_fits(
                 mean_param = np.mean(params)
                 error_mean_param = np.mean(errors) / np.sqrt(len(mask[mask]))
                 mean_error_param = np.mean(errors)
+                std_param = np.std(params)
 
                 mean_param_dict[param_name].append(mean_param)
                 mean_error_dict[param_name].append(mean_error_param)
                 error_mean_dict[param_name].append(error_mean_param)
+                std_dict[param_name].append(std_param)
+                if plot:
+                    if plot_std_error:
+                        if plot_error_bar_of_mean:
+                            error_plot = std_param / np.sqrt(len(mask[mask]))
+                        else:
+                            error_plot = std_param
+                    else:
+                        if plot_error_bar_of_mean:
+                            error_plot = error_mean_param
+                        else:
+                            error_plot = mean_error_param
+                    ax[j].errorbar(
+                        i, mean_param, error_plot, marker=".", ls="None", color="C1"
+                    )
 
-                ax[j].errorbar(
-                    i, mean_param, error_mean_param, marker=".", ls="None", color="C1"
-                )
-
-                ax[j].set_ylabel(param_name, fontsize=18)
+                    ax[j].set_ylabel(param_name, fontsize=18)
 
-                ax2[j].plot(i, mean_error_param, marker=".", ls="None", color="C1")
-                ax2[j].set_ylabel(r"$\sigma$(" + param_name + ")", fontsize=18)
+                    ax2[j].plot(i, error_plot, marker=".", ls="None", color="C1")
+                    ax2[j].set_ylabel(r"$\sigma$(" + param_name + ")", fontsize=18)
 
-                if fiducials is not None:
-                    if fiducials[j] is not None:
-                        ax[j].axhline(fiducials[j], ls=":", color="k")
+                    if fiducials is not None:
+                        if fiducials[j] is not None:
+                            ax[j].axhline(fiducials[j], ls=":", color="k")
 
         text.append(fit_p)
 
     j_index = np.arange(len(unique_fit_prop))
 
-    ax[-1].set_xticks(j_index, np.array(text), rotation=90, fontsize=10)
-    ax[0].margins(x=0.005)
-    fig.tight_layout()
+    if plot:
+        ax[-1].set_xticks(j_index, np.array(text), rotation=90, fontsize=10)
+        ax[0].margins(x=0.005)
+        fig.tight_layout()
+
+        ax2[-1].set_xticks(j_index, np.array(text), rotation=90, fontsize=10)
+        ax2[0].margins(x=0.005)
+        fig2.tight_layout()
+
+    return (
+        unique_fit_prop,
+        mean_param_dict,
+        mean_error_dict,
+        error_mean_dict,
+        std_dict,
+    )
 
-    ax2[-1].set_xticks(j_index, np.array(text), rotation=90, fontsize=10)
-    ax2[0].margins(x=0.005)
-    fig2.tight_layout()
 
-    return unique_fit_prop, mean_param_dict, mean_error_dict, error_mean_dict
+def select_valid_fits(
+    fit_output,
+    subset_plot=None,
+    remove_lower=None,
+    remove_higher=None,
+):
+    all_fit = glob.glob(os.path.join(fit_output, "*"))
+
+    fit_to_plot = []
+    fit_name_to_plot = []
+    for f in all_fit:
+        valid_fit = True
+        if subset_plot is not None:
+            for subset in subset_plot:
+                if subset not in f:
+                    valid_fit = False
+        if valid_fit:
+            fit = pickle.load(open(f, "rb"))
+            if fit[3] is False:
+                valid_fit = False
+            elif fit[4] is False:
+                valid_fit = False
+            if remove_lower is not None:
+                for param in remove_lower.keys():
+                    if fit[0][param] < remove_lower[param]:
+                        valid_fit = False
+            if remove_higher is not None:
+                for param in remove_higher.keys():
+                    if fit[0][param] > remove_higher[param]:
+                        valid_fit = False
+        if valid_fit:
+            fit_to_plot.append(fit)
+            fit_name_to_plot.append(f)
+    return fit_to_plot, fit_name_to_plot
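
Note that subset_plot is now iterated over, so it should be a list of substrings rather than a single string; a file is kept only if its path contains every entry. A hedged usage sketch of the new selection helper (directory and thresholds hypothetical):

    from flip.plot_utils import select_valid_fits

    fits, names = select_valid_fits(
        "fits/",                        # directory of pickled fit results
        subset_plot=["velocity"],       # keep files whose path contains this
        remove_lower={"fs8": 0.1},      # drop fits with fs8 below 0.1
        remove_higher={"sigv": 500.0},  # drop fits with sigv above 500
    )
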
diff --git a/flip/vectors.py b/flip/vectors.py
index 6e232d1..5c9e003 100644
--- a/flip/vectors.py
+++ b/flip/vectors.py
@@ -3,6 +3,11 @@
 from flip import utils
 from flip.utils import create_log
 
+try:
+    import jax.numpy as jnp
+except ImportError:
+    import numpy as jnp
+
 log = create_log()
 
 _avail_velocity_type = ["direct", "scatter", "saltfit"]
@@ -11,16 +16,16 @@
 
 def load_density_error(data):
     if "density_error" in data.keys():
-        density_error = data["density_error"]
+        density_error = jnp.array(data["density_error"])
     else:
         log.add("""No density error in data, loading a null density error""")
-        density_error = np.zeros_like(data["density"])
+        density_error = jnp.zeros_like(data["density"])
     return density_error
 
 
 def load_density_vectors(data):
     if "density" in data.keys():
-        density = data["density"]
+        density = jnp.array(data["density"])
         density_error = load_density_error(data)
     else:
         raise ValueError("""The data does not contains a density field""")
@@ -151,15 +156,15 @@ def load_velocity_vectors(
 
 def get_velocity_error_directly(data):
     if "velocity_error" in data.keys():
-        velocity_error = data["velocity_error"]
+        velocity_error = jnp.array(data["velocity_error"])
     else:
         log.add("""No velocity error in data, loading a null velocity error""")
-        velocity_error = np.zeros_like(data["velocity"])
+        velocity_error = jnp.zeros_like(data["velocity"])
     return velocity_error
 
 
 def get_velocity_directly(data):
-    velocity = data["velocity"]
+    velocity = jnp.array(data["velocity"])
     velocity_error = get_velocity_error_directly(data)
     return velocity, velocity_error
 
@@ -186,7 +191,7 @@ def get_velocity_from_scatter(
     parameter_values_dict,
     velocity_estimator,
 ):
-    velocity = data["velocity"]
+    velocity = jnp.array(data["velocity"])
 
     velocity_error = get_velocity_error_from_scatter(
         data,
@@ -216,7 +221,7 @@ def get_velocity_error_from_salt_fit(
         velocity_estimator,
     )
 
-    velocity_error = redshift_dependence * np.sqrt(variance_mu)
+    velocity_error = redshift_dependence * jnp.sqrt(variance_mu)
 
     return velocity_error, redshift_dependence
 
@@ -239,7 +244,9 @@ def get_velocity_from_salt_fit(
         parameter_values_dict["beta"],
         parameter_values_dict["M_0"],
     )
-    muth = 5 * np.log10((1 + data["zobs"]) * data["rcom_zobs"]) + 25
+    muth = (
+        5 * jnp.log10((1 + jnp.array(data["zobs"])) * jnp.array(data["rcom_zobs"])) + 25
+    )
     dmu = mu - muth
     velocity = -redshift_dependence * dmu
 
@@ -253,7 +260,12 @@ def compute_observed_distance_modulus(
     beta,
     M0,
 ):
-    mu = data["mb"] + alpha * data["x1"] - beta * data["c"] - M0
+    mu = (
+        jnp.array(data["mb"])
+        + alpha * jnp.array(data["x1"])
+        - beta * jnp.array(data["c"])
+        - M0
+    )
     return mu
 
 
@@ -264,12 +276,14 @@ def compute_observed_distance_modulus_error(
     beta,
 ):
     variance_mu = (
-        data["e_mb"] ** 2 + alpha**2 * data["e_x1"] ** 2 + beta**2 * data["e_c"] ** 2
+        jnp.array(data["e_mb"]) ** 2
+        + alpha**2 * jnp.array(data["e_x1"]) ** 2
+        + beta**2 * jnp.array(data["e_c"]) ** 2
     )
     variance_mu += (
-        2 * alpha * data["cov_mb_x1"]
-        - 2 * beta * data["cov_mb_c"]
-        - 2 * alpha * beta * data["cov_x1_c"]
+        2 * alpha * jnp.array(data["cov_mb_x1"])
+        - 2 * beta * jnp.array(data["cov_mb_c"])
+        - 2 * alpha * beta * jnp.array(data["cov_x1_c"])
     )
     return variance_mu
 
@@ -279,8 +293,8 @@ def redshift_dependence_velocity(
     parameter_values_dict,
     velocity_estimator,
 ):
-    prefactor = utils._C_LIGHT_KMS_ * np.log(10) / 5
-    redshift_obs = data["zobs"]
+    prefactor = utils._C_LIGHT_KMS_ * jnp.log(10) / 5
+    redshift_obs = jnp.array(data["zobs"])
 
     if velocity_estimator == "watkins":
         redshift_dependence = prefactor * redshift_obs / (1 + redshift_obs)
@@ -311,7 +325,7 @@ def redshift_dependence_velocity(
         redshift_dependence = prefactor / (
             (1 + redshift_obs)
             * utils._C_LIGHT_KMS_
-            / (data["hubble_norm"] * data["rcom_zobs"])
+            / (jnp.array(data["hubble_norm"]) * jnp.array(data["rcom_zobs"]))
             - 1.0
         )
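
The SALT-fit path above implements the Tripp relation mu_obs = mb + alpha*x1 - beta*c - M0 against mu_th = 5*log10((1 + zobs) * rcom_zobs) + 25, and converts the residual to a peculiar velocity with v = -J(z) * (mu_obs - mu_th), where J(z) = (c * ln10 / 5) * z / (1 + z) for the "watkins" estimator. A self-contained numeric check, with all numbers illustrative:

    import numpy as np

    C_LIGHT_KMS = 299792.458  # as utils._C_LIGHT_KMS_
    z, rcom = 0.05, 210.0     # toy observed redshift and comoving distance [Mpc]
    mb, x1, c_, alpha, beta, M0 = 17.55, 0.1, -0.05, 0.14, 3.1, -19.1

    mu_obs = mb + alpha * x1 - beta * c_ - M0          # Tripp relation
    mu_th = 5 * np.log10((1 + z) * rcom) + 25
    J = (C_LIGHT_KMS * np.log(10) / 5) * z / (1 + z)   # "watkins" estimator
    print(-J * (mu_obs - mu_th))                       # peculiar velocity in km/s
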
diff --git a/notebook/fit_velocity_jax.ipynb b/notebook/fit_velocity_jax.ipynb
new file mode 100644
index 0000000..ccebe1e
--- /dev/null
+++ b/notebook/fit_velocity_jax.ipynb
@@ -0,0 +1,705 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "71e15407-b85a-4152-858f-86087f5fae7d",
+   "metadata": {},
+   "source": [
+    "This is a tutorial to use the flip package: https://github.com/corentinravoux/flip \\\n",
+    "It is self-contained and can be used in Google Colab or in your own environment. \\\n",
+    "All the data used are subsampled versions of a simulation. \\\n",
+    "The data size is small for the tutorial, do not use it for science cases. \\"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "7388f2f8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+      "  pid, fd = os.forkpty()\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Found existing installation: jax 0.4.30\n",
+      "Uninstalling jax-0.4.30:\n",
+      "  Successfully uninstalled jax-0.4.30\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip uninstall -y jax"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "678715b7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Defaulting to user installation because normal site-packages is not writeable\n",
+      "Requirement already satisfied: jax in /home/ravoux/.local/lib/python3.10/site-packages (0.4.30)\n",
+      "Requirement already satisfied: scipy>=1.9 in /home/ravoux/.local/lib/python3.10/site-packages (from jax) (1.11.3)\n",
+      "Requirement already satisfied: numpy>=1.22 in /home/ravoux/.local/lib/python3.10/site-packages (from jax) (1.22.4)\n",
+      "Requirement already satisfied: opt-einsum in /home/ravoux/.local/lib/python3.10/site-packages (from jax) (3.3.0)\n",
+      "Requirement already satisfied: ml-dtypes>=0.2.0 in /home/ravoux/.local/lib/python3.10/site-packages (from jax) (0.4.0)\n",
+      "Requirement already satisfied: jaxlib<=0.4.30,>=0.4.27 in /home/ravoux/.local/lib/python3.10/site-packages (from jax) (0.4.30)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install jax"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "978011a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%capture\n",
+    "!pip install ../\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "7e1da453-ef47-4920-b330-a7d9289e2f01",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/home/ravoux/.local/lib/python3.10/site-packages/matplotlib/projections/__init__.py:63: UserWarning: Unable to import Axes3D. This may be due to multiple versions of Matplotlib being installed (e.g. as a system package and as a pip package). As a result, the 3D projection is not available.\n",
+      "  warnings.warn(\"Unable to import Axes3D. This may be due to multiple versions of \"\n",
+      "[ 000000.00 ]: 07-26 16:58 root INFO No pypower module detected, gridding with this method is unavailable\n",
+      "[ 000000.00 ]: 07-26 16:58 root WARNING Install CLASS https://github.com/lesgourg/class_public to use class.py module\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "from flip import fitter, plot_utils, utils, vectors\n",
+    "from flip.covariance import covariance, contraction\n",
+    "from pkg_resources import resource_filename\n",
+    "flip_base = resource_filename(\"flip\", \".\")\n",
+    "data_path = os.path.join(flip_base, \"data\")\n",
+    "plt.style.use(os.path.join(data_path,\"style.mplstyle\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "4d8f20f1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "use_jit = True\n",
+    "use_gradient = True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "36df6de3",
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[ 000004.02 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.02 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.03 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.04 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.04 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.04 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.04 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.04 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.04 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.05 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.05 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.05 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.05 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.05 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000004.30 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'cuda': \n",
+      "[ 000004.30 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'\n",
+      "[ 000004.31 ]: 07-26 16:58 jax._src.xla_bridge INFO Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory\n",
+      "[ 000005.24 ]: 07-26 16:58 root INFO Covariance matrix generated from flip with adamsblake20 model in 1.48e+00 seconds\n",
+      "[ 000012.75 ]: 07-26 16:58 root INFO ┌─────────────────────────────────────────────────────────────────────────┐\n",
+      "│ Migrad │\n",
+      "├──────────────────────────────────┬──────────────────────────────────────┤\n",
+      "│ FCN = 1523 │ Nfcn = 42, Ngrad = 4 │\n",
+      "│ EDM = 5.67e-11 (Goal: 0.0002) │ time = 7.3 sec │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Valid Minimum │ Below EDM threshold (goal x 10) │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ No parameters at limit │ Below call limit │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Hesse ok │ Covariance accurate │\n",
+      "└──────────────────────────────────┴──────────────────────────────────────┘\n",
+      "┌───┬────────┬───────────┬───────────┬────────────┬────────────┬─────────┬─────────┬───────┐\n",
+      "│ │ Name │ Value │ Hesse Err │ Minos Err- │ Minos Err+ │ Limit- │ Limit+ │ Fixed │\n",
+      "├───┼────────┼───────────┼───────────┼────────────┼────────────┼─────────┼─────────┼───────┤\n",
+      "│ 0 │ bs8 │ 0.82 │ 0.05 │ │ │ 0 │ │ │\n",
+      "│ 1 │ beta_f │ 0.31 │ 0.11 │ │ │ 0 │ │ │\n",
+      "└───┴────────┴───────────┴───────────┴────────────┴────────────┴─────────┴─────────┴───────┘\n",
+      "┌────────┬─────────────────┐\n",
+      "│ │ bs8 beta_f │\n",
+      "├────────┼─────────────────┤\n",
+      "│ bs8 │ 0.00214 -0.0044 │\n",
+      "│ beta_f │ -0.0044 0.0112 │\n",
+      "└────────┴─────────────────┘\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'bs8': 0.8228999982398881, 'beta_f': 0.3086451044958614}"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
"source": [ + "grid1 = pd.read_parquet(os.path.join(data_path,\"density_data.parquet\"))\n", + "grid2 = grid1.copy()\n", + "grid2[\"x\"] = grid2[\"x\"] + 200\n", + "grid3 = grid1.copy()\n", + "grid3[\"x\"] = grid1[\"x\"] - 200\n", + "grid = pd.concat([grid1])\n", + "\n", + "grid_window = pd.read_parquet(os.path.join(data_path,\"grid_window_m.parquet\"))\n", + "\n", + "coordinates_density = np.array([grid[\"ra\"], grid[\"dec\"], grid[\"rcom\"]])\n", + "data_density = {'density': np.array(grid[\"density\"]), \n", + " 'density_error': np.array(grid[\"density_err\"])}\n", + "\n", + "\n", + "ktt, ptt = np.loadtxt(os.path.join(data_path,\"power_spectrum_tt.txt\"))\n", + "kmt, pmt = np.loadtxt(os.path.join(data_path,\"power_spectrum_mt.txt\"))\n", + "kmm, pmm = np.loadtxt(os.path.join(data_path,\"power_spectrum_mm.txt\"))\n", + "\n", + "sigmag_fiducial = 3.0\n", + "\n", + "power_spectrum_dict_bias = {\"gg\": [[kmm, pmm * np.array(grid_window[\"window_mm\"])**2]]}\n", + "power_spectrum_dict = {\"gg\": [[kmm, pmm * np.array(grid_window[\"window_mm\"])**2],\n", + " [kmt, pmt * np.array(grid_window[\"window_mt\"])],\n", + " [ktt, ptt]]}\n", + "\n", + "size_batch=10_000\n", + "number_worker=8\n", + "\n", + "covariance_fit = covariance.CovMatrix.init_from_flip(\n", + " \"adamsblake20\",\n", + " \"density\",\n", + " power_spectrum_dict,\n", + " coordinates_density=coordinates_density,\n", + " size_batch=size_batch,\n", + " number_worker=number_worker,\n", + " additional_parameters_values=(sigmag_fiducial,),\n", + ")\n", + "\n", + "\n", + "likelihood_type=\"multivariate_gaussian\"\n", + "likelihood_properties = {\"inversion_method\": \"cholesky\",\n", + " \"use_jit\":use_jit,\n", + " \"use_gradient\":use_gradient}\n", + "\n", + "parameter_dict = {\"bs8\": {\"value\": 1.0,\n", + " \"limit_low\" : 0.0,\n", + " \"fixed\" : False,\n", + " },\n", + " \"beta_f\": {\"value\": 0.4,\n", + " \"limit_low\" : 0.0,\n", + " \"fixed\" : False,\n", + " },\n", + " }\n", + "\n", + "\n", + "\n", + "\n", + "minuit_fitter = fitter.FitMinuit.init_from_covariance(\n", + " covariance_fit,\n", + " data_density,\n", + " parameter_dict,\n", + " likelihood_type=likelihood_type,\n", + " likelihood_properties = likelihood_properties)\n", + "\n", + "minuit_fitter.run()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f81a2bc9-1d9e-4969-ac2d-0b5e56fd189c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. 
+      "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+      "  self.pid = os.fork()\n",
+      "[ 000099.26 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.70e-01 seconds\n",
+      "[ 000102.90 ]: 07-26 15:02 root INFO ┌─────────────────────────────────────────────────────────────────────────┐\n",
+      "│ Migrad │\n",
+      "├──────────────────────────────────┬──────────────────────────────────────┤\n",
+      "│ FCN = 1.374e+04 │ Nfcn = 35, Ngrad = 2 │\n",
+      "│ EDM = 3.49e-05 (Goal: 0.0002) │ time = 3.6 sec │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Valid Minimum │ Below EDM threshold (goal x 10) │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ No parameters at limit │ Below call limit │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Hesse ok │ Covariance accurate │\n",
+      "└──────────────────────────────────┴──────────────────────────────────────┘\n",
+      "┌───┬──────┬───────────┬───────────┬────────────┬────────────┬─────────┬─────────┬───────┐\n",
+      "│ │ Name │ Value │ Hesse Err │ Minos Err- │ Minos Err+ │ Limit- │ Limit+ │ Fixed │\n",
+      "├───┼──────┼───────────┼───────────┼────────────┼────────────┼─────────┼─────────┼───────┤\n",
+      "│ 0 │ fs8 │ 0.353 │ 0.025 │ │ │ 0 │ │ │\n",
+      "│ 1 │ sigv │ 200 │ 6 │ │ │ 0 │ │ │\n",
+      "└───┴──────┴───────────┴───────────┴────────────┴────────────┴─────────┴─────────┴───────┘\n",
+      "┌──────┬───────────────────┐\n",
+      "│ │ fs8 sigv │\n",
+      "├──────┼───────────────────┤\n",
+      "│ fs8 │ 0.000619 -40.7e-3 │\n",
+      "│ sigv │ -40.7e-3 30.3 │\n",
+      "└──────┴───────────────────┘\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'fs8': 0.3527380146365558, 'sigv': 199.53910751729765}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data_velocity = pd.read_parquet(os.path.join(data_path, \"velocity_data.parquet\"))\n",
+    "\n",
+    "data_velocity_true = {\"velocity\": data_velocity[\"vpec\"].values, \n",
+    "                      \"velocity_error\": np.zeros(len(data_velocity[\"vpec\"].values))}\n",
+    "\n",
+    "coordinates_velocity = np.array([data_velocity[\"ra\"], \n",
+    "                                 data_velocity[\"dec\"], \n",
+    "                                 data_velocity[\"rcom_zobs\"]])\n",
+    "\n",
+    "ktt, ptt = np.loadtxt(os.path.join(data_path, \"power_spectrum_tt.txt\"))\n",
+    "kmt, pmt = np.loadtxt(os.path.join(data_path, \"power_spectrum_mt.txt\"))\n",
+    "kmm, pmm = np.loadtxt(os.path.join(data_path, \"power_spectrum_mm.txt\"))\n",
+    "\n",
+    "sigmau_fiducial = 15\n",
+    "\n",
+    "power_spectrum_dict = {\"vv\": [[ktt, ptt * utils.Du(ktt, sigmau_fiducial) ** 2]]}\n",
+    "\n",
+    "size_batch=10_000\n",
+    "number_worker=16\n",
+    "\n",
+    "\n",
+    "covariance_fit = covariance.CovMatrix.init_from_flip(\n",
+    "    \"carreres23\",\n",
+    "    \"velocity\",\n",
+    "    power_spectrum_dict,\n",
+    "    coordinates_velocity=coordinates_velocity,\n",
+    "    size_batch=size_batch,\n",
+    "    number_worker=number_worker,\n",
+    ")\n",
+    "\n",
+    "likelihood_type=\"multivariate_gaussian\"\n",
+    "likelihood_properties = {\"inversion_method\": \"cholesky\",\n",
+    "                         \"velocity_type\": \"direct\",\n",
+    "                         \"use_jit\": use_jit,\n",
+    "                         \"use_gradient\": use_gradient}\n",
+    "\n",
+    "\n",
+    "parameter_dict = {\"fs8\": {\"value\": 0.4,\n",
+    "                          \"limit_low\" : 0.0,\n",
+    "                          \"fixed\" : False,\n",
+    "                          },\n",
+    "                  \"sigv\": {\"value\": 200,\n",
+    "                           \"limit_low\" : 0.0,\n",
+    "                           \"fixed\" : False,\n",
+    "power_spectrum_dict = {\"vv\": [[ktt, ptt * utils.Du(ktt, sigmau_fiducial) ** 2]]}\n",
+    "\n",
+    "size_batch = 10_000\n",
+    "number_worker = 16\n",
+    "\n",
+    "# Build the velocity-velocity covariance matrix with the carreres23 model.\n",
+    "covariance_fit = covariance.CovMatrix.init_from_flip(\n",
+    "    \"carreres23\",\n",
+    "    \"velocity\",\n",
+    "    power_spectrum_dict,\n",
+    "    coordinates_velocity=coordinates_velocity,\n",
+    "    size_batch=size_batch,\n",
+    "    number_worker=number_worker,\n",
+    ")\n",
+    "\n",
+    "likelihood_type = \"multivariate_gaussian\"\n",
+    "likelihood_properties = {\"inversion_method\": \"cholesky\",\n",
+    "                         \"velocity_type\": \"direct\",\n",
+    "                         \"use_jit\": use_jit,\n",
+    "                         \"use_gradient\": use_gradient}\n",
+    "\n",
+    "# Free parameters of the fit: only a lower bound is set on fs8 and sigv.\n",
+    "parameter_dict = {\"fs8\": {\"value\": 0.4,\n",
+    "                          \"limit_low\": 0.0,\n",
+    "                          \"fixed\": False,\n",
+    "                          },\n",
+    "                  \"sigv\": {\"value\": 200,\n",
+    "                           \"limit_low\": 0.0,\n",
+    "                           \"fixed\": False,\n",
+    "                           },\n",
+    "                  }\n",
+    "\n",
+    "minuit_fitter = fitter.FitMinuit.init_from_covariance(\n",
+    "    covariance_fit,\n",
+    "    data_velocity_true,\n",
+    "    parameter_dict,\n",
+    "    likelihood_type=likelihood_type,\n",
+    "    likelihood_properties=likelihood_properties,\n",
+    ")\n",
+    "\n",
+    "minuit_fitter.run()\n",
+    "\n",
+    "# The likelihood can also be evaluated directly at a given parameter vector:\n",
+    "# parameter_values = [\n",
+    "#     parameter_dict[parameters][\"value\"] for parameters in parameter_dict\n",
+    "# ]\n",
+    "# minuit_fitter.likelihood(parameter_values)\n"
+   ]
+  },
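+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "b2f1a0c3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sketch: inspect the parameter errors and covariance of the fit.\n",
+    "# This assumes the FitMinuit wrapper exposes the underlying iminuit object\n",
+    "# as its `minuit` attribute; adapt the attribute name to your flip version.\n",
+    "print(minuit_fitter.minuit.values)\n",
+    "print(minuit_fitter.minuit.errors)\n",
+    "print(minuit_fitter.minuit.covariance)\n"
+   ]
+  },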
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "e6914dc3",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+      "  self.pid = os.fork()\n",
+      "[ 000108.42 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.67e-01 seconds\n",
+      "[ 000109.02 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.99e-01 seconds\n",
+      "[ 000109.61 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.86e-01 seconds\n",
+      "[ 000110.21 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 6.00e-01 seconds\n",
+      "[ 000110.81 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 6.01e-01 seconds\n",
+      "[ 000111.41 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.99e-01 seconds\n",
+      "[ 000112.00 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.89e-01 seconds\n",
+      "[ 000112.60 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 5.96e-01 seconds\n",
+      "[ 000113.21 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 6.10e-01 seconds\n",
+      "[ 000113.81 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 6.00e-01 seconds\n",
+      "[ 000122.60 ]: 07-26 15:02 root INFO ┌─────────────────────────────────────────────────────────────────────────┐\n",
+      "│ Migrad │\n",
+      "├──────────────────────────────────┬──────────────────────────────────────┤\n",
+      "│ FCN = 1.374e+04 │ Nfcn = 47, Ngrad = 3 │\n",
+      "│ EDM = 0.000171 (Goal: 0.0002) │ time = 8.1 sec │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Valid Minimum │ Below EDM threshold (goal x 10) │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ SOME parameters at limit │ Below call limit │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Hesse ok │ Covariance accurate │\n",
+      "└──────────────────────────────────┴──────────────────────────────────────┘\n",
+      "┌───┬──────┬───────────┬───────────┬────────────┬────────────┬─────────┬─────────┬───────┐\n",
+      "│ │ Name │ Value │ Hesse Err │ Minos Err- │ Minos Err+ │ Limit- │ Limit+ │ Fixed │\n",
+      "├───┼──────┼───────────┼───────────┼────────────┼────────────┼─────────┼─────────┼───────┤\n",
+      "│ 0 │ fs8 │ 0.366 │ 0.027 │ │ │ 0 │ │ │\n",
+      "│ 1 │ sigv │ 201 │ 5 │ │ │ 0 │ │ │\n",
+      "│ 2 │ sigu │ 17.0 │ 2.9 │ │ │ 13 │ 17 │ │\n",
+      "└───┴──────┴───────────┴───────────┴────────────┴────────────┴─────────┴─────────┴───────┘\n",
+      "┌──────┬────────────────────────────┐\n",
+      "│ │ fs8 sigv sigu │\n",
+      "├──────┼────────────────────────────┤\n",
+      "│ fs8 │ 0.000704 -40.8e-3 0.1e-3 │\n",
+      "│ sigv │ -40.8e-3 29.7 0.0052 │\n",
+      "│ sigu │ 0.1e-3 0.0052 0.00245 │\n",
+      "└──────┴────────────────────────────┘\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'fs8': 0.36614482471524257,\n",
+       " 'sigv': 200.8121744175942,\n",
+       " 'sigu': 16.99972091553989}"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "data_velocity = pd.read_parquet(os.path.join(data_path, \"velocity_data.parquet\"))\n",
+    "\n",
+    "data_velocity_true = {\"velocity\": data_velocity[\"vpec\"].values,\n",
+    "                      \"velocity_error\": np.zeros(len(data_velocity[\"vpec\"].values))}\n",
+    "\n",
+    "coordinates_velocity = np.array([data_velocity[\"ra\"],\n",
+    "                                 data_velocity[\"dec\"],\n",
+    "                                 data_velocity[\"rcom_zobs\"]])\n",
+    "\n",
+    "ktt, ptt = np.loadtxt(os.path.join(data_path, \"power_spectrum_tt.txt\"))\n",
+    "kmt, pmt = np.loadtxt(os.path.join(data_path, \"power_spectrum_mt.txt\"))\n",
+    "kmm, pmm = np.loadtxt(os.path.join(data_path, \"power_spectrum_mm.txt\"))\n",
+    "\n",
+    "size_batch = 10_000\n",
+    "number_worker = 16\n",
+    "\n",
+    "# Precompute one covariance matrix per sigma_u value: the interpolated\n",
+    "# likelihood below treats sigu as a free parameter by interpolating\n",
+    "# between these matrices.\n",
+    "sigmau_list = np.linspace(10.0, 20.0, 10)\n",
+    "covariance_list = []\n",
+    "\n",
+    "for sigu in sigmau_list:\n",
+    "\n",
+    "    power_spectrum_dict = {\"vv\": [[ktt, ptt * utils.Du(ktt, sigu) ** 2]]}\n",
+    "\n",
+    "    covariance_list.append(covariance.CovMatrix.init_from_flip(\n",
+    "        \"carreres23\",\n",
+    "        \"velocity\",\n",
+    "        power_spectrum_dict,\n",
+    "        coordinates_velocity=coordinates_velocity,\n",
+    "        size_batch=size_batch,\n",
+    "        number_worker=number_worker,\n",
+    "    ))\n",
+    "\n",
+    "likelihood_type = \"multivariate_gaussian_interp1d\"\n",
+    "likelihood_properties = {\"inversion_method\": \"cholesky\",\n",
+    "                         \"velocity_type\": \"direct\",\n",
+    "                         \"use_jit\": use_jit,\n",
+    "                         \"use_gradient\": use_gradient}\n",
+    "\n",
+    "# sigu must stay inside the interpolation range; here it is further\n",
+    "# restricted to [13, 17].\n",
+    "parameter_dict = {\"fs8\": {\"value\": 0.4,\n",
+    "                          \"limit_low\": 0.0,\n",
+    "                          \"fixed\": False,\n",
+    "                          },\n",
+    "                  \"sigv\": {\"value\": 200,\n",
+    "                           \"limit_low\": 0.0,\n",
+    "                           \"fixed\": False,\n",
+    "                           },\n",
+    "                  \"sigu\": {\"value\": 15.0,\n",
+    "                           \"limit_low\": 13.0,\n",
+    "                           \"limit_up\": 17.0,\n",
+    "                           \"fixed\": False,\n",
+    "                           },\n",
+    "                  }\n",
+    "\n",
+    "minuit_fitter_interp = fitter.FitMinuit.init_from_covariance(\n",
+    "    covariance_list,\n",
+    "    data_velocity_true,\n",
+    "    parameter_dict,\n",
+    "    likelihood_type=likelihood_type,\n",
+    "    likelihood_properties=likelihood_properties,\n",
+    "    interpolation_value_name=\"sigu\",\n",
+    "    interpolation_value_range=sigmau_list,\n",
+    ")\n",
+    "\n",
+    "minuit_fitter_interp.run()\n"
+   ]
+  },
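+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c4d5e6f7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sketch: scan the interpolated likelihood along the sigu grid,\n",
+    "# following the direct-evaluation pattern shown (commented out) in the first\n",
+    "# fit cell. The parameter vector order matches parameter_dict (fs8, sigv,\n",
+    "# sigu); adapt if your flip version uses a different likelihood signature.\n",
+    "for sigu in sigmau_list:\n",
+    "    print(sigu, minuit_fitter_interp.likelihood([0.4, 200.0, sigu]))\n"
+   ]
+  },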
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "0ce455be-c974-467b-923a-b23b1b1f95d1",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+      "  self.pid = os.fork()\n",
+      "[ 000127.40 ]: 07-26 15:02 root INFO Covariance matrix generated from flip with carreres23 model in 6.18e-01 seconds\n",
+      "/home/ravoux/.local/lib/python3.10/site-packages/iminuit/util.py:172: IMinuitWarning: Assigned errors must be positive. Non-positive values are replaced by a heuristic.\n",
+      "  warnings.warn(\n",
+      "[ 000136.93 ]: 07-26 15:02 root INFO ┌─────────────────────────────────────────────────────────────────────────┐\n",
+      "│ Migrad │\n",
+      "├──────────────────────────────────┬──────────────────────────────────────┤\n",
+      "│ FCN = 1.415e+04 │ Nfcn = 139, Ngrad = 8 │\n",
+      "│ EDM = 1.38e-05 (Goal: 0.0002) │ time = 9.5 sec │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Valid Minimum │ Below EDM threshold (goal x 10) │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ No parameters at limit │ Below call limit │\n",
+      "├──────────────────────────────────┼──────────────────────────────────────┤\n",
+      "│ Hesse ok │ Covariance accurate │\n",
+      "└──────────────────────────────────┴──────────────────────────────────────┘\n",
+      "┌───┬─────────┬───────────┬───────────┬────────────┬────────────┬─────────┬─────────┬───────┐\n",
+      "│ │ Name │ Value │ Hesse Err │ Minos Err- │ Minos Err+ │ Limit- │ Limit+ │ Fixed │\n",
+      "├───┼─────────┼───────────┼───────────┼────────────┼────────────┼─────────┼─────────┼───────┤\n",
+      "│ 0 │ fs8 │ 0.47 │ 0.12 │ │ │ 0 │ 1 │ │\n",
+      "│ 1 │ sigv │ 590 │ 90 │ │ │ 0 │ 1000 │ │\n",
+      "│ 2 │ alpha │ 0.133 │ 0.004 │ │ │ 0 │ 0.25 │ │\n",
+      "│ 3 │ beta │ 3.10 │ 0.04 │ │ │ 2.5 │ 4.5 │ │\n",
+      "│ 4 │ M_0 │ -19.068 │ 0.010 │ │ │ -21 │ -18 │ │\n",
+      "│ 5 │ sigma_M │ 0.086 │ 0.007 │ │ │ 0 │ 100 │ │\n",
+      "└───┴─────────┴───────────┴───────────┴────────────┴────────────┴─────────┴─────────┴───────┘\n",
+      "┌─────────┬───────────────────────────────────────────────────────────────────┐\n",
+      "│ │ fs8 sigv alpha beta M_0 sigma_M │\n",
+      "├─────────┼───────────────────────────────────────────────────────────────────┤\n",
+      "│ fs8 │ 0.0145 -1.812 -0.011e-3 -0.0001 -0.19e-3 0.05e-3 │\n",
+      "│ sigv │ -1.812 8.89e+03 3.938e-3 0.0023 5.53e-3 -566.95e-3 │\n",
+      "│ alpha │ -0.011e-3 3.938e-3 1.79e-05 -0.001e-3 0 -0.001e-3 │\n",
+      "│ beta │ -0.0001 0.0023 -0.001e-3 0.0016 -0.08e-3 -0.01e-3 │\n",
+      "│ M_0 │ -0.19e-3 5.53e-3 0 -0.08e-3 0.000107 0 │\n",
+      "│ sigma_M │ 0.05e-3 -566.95e-3 -0.001e-3 -0.01e-3 0 4.88e-05 │\n",
+      "└─────────┴───────────────────────────────────────────────────────────────────┘\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'fs8': 0.47381309968970364,\n",
+       " 'sigv': 593.1049535697847,\n",
+       " 'alpha': 0.13285894991736366,\n",
+       " 'beta': 3.099666399471431,\n",
+       " 'M_0': -19.067819495828125,\n",
+       " 'sigma_M': 0.0863770652902524}"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# Trial SALT standardization parameters used to derive first-pass velocity\n",
+    "# estimates and their errors from the supernova fit quantities.\n",
+    "test_parameters = {\"alpha\": 0.14,\n",
+    "                   \"beta\": 3.1,\n",
+    "                   \"M_0\": -19.133,\n",
+    "                   \"sigma_M\": 0.12}\n",
+    "\n",
+    "velocity, velocity_error = vectors.get_velocity_from_salt_fit(\n",
+    "    data_velocity,\n",
+    "    test_parameters,\n",
+    "    \"full\",\n",
+    ")\n",
+    "\n",
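+    "# Remove objects with unreliable estimates (|velocity| or error above\n",
+    "# 3000 km/s) and propagate the mask to the coordinates and data columns.\n",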
+    "mask = np.abs(velocity) < 3_000\n",
+    "mask &= velocity_error < 3_000\n",
+    "\n",
+    "coordinates_velocity_cleaned = np.transpose(np.transpose(coordinates_velocity)[mask])\n",
+    "data_velocity_cleaned = {}\n",
+    "for key in list(data_velocity.keys()):\n",
+    "    data_velocity_cleaned[key] = np.array(data_velocity[key])[mask]\n",
+    "\n",
+    "# Rebuild the power spectrum at the fiducial sigma_u: the loop in the\n",
+    "# previous cell left power_spectrum_dict at the last grid value.\n",
+    "power_spectrum_dict = {\"vv\": [[ktt, ptt * utils.Du(ktt, sigmau_fiducial) ** 2]]}\n",
+    "\n",
+    "size_batch = 10_000\n",
+    "number_worker = 16\n",
+    "\n",
+    "covariance_fit = covariance.CovMatrix.init_from_flip(\n",
+    "    \"carreres23\",\n",
+    "    \"velocity\",\n",
+    "    power_spectrum_dict,\n",
+    "    coordinates_velocity=coordinates_velocity_cleaned,\n",
+    "    size_batch=size_batch,\n",
+    "    number_worker=number_worker,\n",
+    ")\n",
+    "\n",
+    "# With the \"saltfit\" velocity type, the likelihood derives the velocities\n",
+    "# from the SALT parameters, which are fitted jointly with fs8 and sigv.\n",
+    "likelihood_type = \"multivariate_gaussian\"\n",
+    "likelihood_properties = {\"inversion_method\": \"cholesky\",\n",
+    "                         \"velocity_type\": \"saltfit\",\n",
+    "                         \"velocity_estimator\": \"full\",\n",
+    "                         \"use_jit\": use_jit,\n",
+    "                         \"use_gradient\": use_gradient,\n",
+    "                         }\n",
+    "\n",
+    "parameter_dict = {\"fs8\": {\"value\": 0.4,\n",
+    "                          \"limit_low\": 0.0,\n",
+    "                          \"limit_up\": 1.0,\n",
+    "                          \"fixed\": False,\n",
+    "                          },\n",
+    "                  \"sigv\": {\"value\": 200,\n",
+    "                           \"limit_low\": 0.0,\n",
+    "                           \"limit_up\": 1000,\n",
+    "                           \"fixed\": False,\n",
+    "                           },\n",
+    "                  \"alpha\": {\"value\": 0.1,\n",
+    "                            \"limit_low\": 0.0,\n",
+    "                            \"limit_up\": 0.25,\n",
+    "                            \"fixed\": False,\n",
+    "                            },\n",
+    "                  \"beta\": {\"value\": 3.0,\n",
+    "                           \"limit_low\": 2.5,\n",
+    "                           \"limit_up\": 4.5,\n",
+    "                           \"fixed\": False,\n",
+    "                           },\n",
+    "                  \"M_0\": {\"value\": -19,\n",
+    "                          \"limit_low\": -21,\n",
+    "                          \"limit_up\": -18,\n",
+    "                          \"fixed\": False,\n",
+    "                          },\n",
+    "                  \"sigma_M\": {\"value\": 0.1,\n",
+    "                              \"limit_low\": 0.0,\n",
+    "                              \"limit_up\": 100.0,\n",
+    "                              \"fixed\": False,\n",
+    "                              }}\n",
+    "\n",
+    "minuit_fitter_velest = fitter.FitMinuit.init_from_covariance(\n",
+    "    covariance_fit,\n",
+    "    data_velocity_cleaned,\n",
+    "    parameter_dict,\n",
+    "    likelihood_type=likelihood_type,\n",
+    "    likelihood_properties=likelihood_properties,\n",
+    ")\n",
+    "\n",
+    "minuit_fitter_velest.run()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "56a21e83",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}