diff --git a/.github/workflows/bumpversion.yml b/.github/workflows/bumpversion.yml index dd528ac..16eefb4 100644 --- a/.github/workflows/bumpversion.yml +++ b/.github/workflows/bumpversion.yml @@ -34,7 +34,7 @@ jobs: run: bump2version patch --verbose - name: Bump major version - if: ${{ ontains(github.event.head_commit.message, '[bump major]') }} + if: ${{contains(github.event.head_commit.message, '[bump major]') }} run: bump2version major --verbose --tag - name: Bump minor version diff --git a/flip/data_vector/basic.py b/flip/data_vector/basic.py index 7efa133..d90dd98 100644 --- a/flip/data_vector/basic.py +++ b/flip/data_vector/basic.py @@ -43,8 +43,11 @@ class DataVector(abc.ABC): _kind (str): One of "velocity", "density" or "cross". """ - _free_par = [] _kind = "" # 'velocity', 'density' or 'cross' + _needed_keys = [] + _free_par = [] + _number_dimension_observation_covariance = 0 + _parameters_observation_covariance = [] @property def conditional_free_par(self): @@ -109,16 +112,6 @@ def give_data_and_variance(self, **kwargs): """ pass - def _check_keys(self, data): - """Validate that `data` contains all required keys. - - Raises: - ValueError: When a required key is missing. - """ - for k in self.needed_keys: - if k not in data: - raise ValueError(f"{k} field is needed in data") - def __init__(self, data, covariance_observation=None, **kwargs): """Initialize data vector with data and optional observation covariance. @@ -129,6 +122,8 @@ def __init__(self, data, covariance_observation=None, **kwargs): """ self._covariance_observation = covariance_observation self._check_keys(data) + self._number_datapoints = len(data[self.needed_keys[0]]) + self.check_covariance_observation() self._data = copy.copy(data) self._kwargs = kwargs @@ -138,6 +133,31 @@ def __init__(self, data, covariance_observation=None, **kwargs): if jax_installed: self.give_data_and_variance_jit = jit(self.give_data_and_variance) + def check_covariance_observation(self): + if self._covariance_observation is not None: + if self._covariance_observation.shape != ( + self._number_dimension_observation_covariance * self._number_datapoints, + self._number_dimension_observation_covariance * self._number_datapoints, + ): + raise ValueError( + f"Observation covariance matrix should be {self._number_dimension_observation_covariance}N " + f"x {self._number_dimension_observation_covariance}N" + ) + log.add( + f"Loading observation covariance matrix, " + f"expecting {self._parameters_observation_covariance} parameters." + ) + + def _check_keys(self, data): + """Validate that `data` contains all required keys. + + Raises: + ValueError: When a required key is missing. + """ + for k in self.needed_keys: + if k not in data: + raise ValueError(f"{k} field is needed in data") + def get_masked_data_and_cov(self, bool_mask): """Return masked data and corresponding masked observation covariance. @@ -189,6 +209,9 @@ def compute_covariance(self, model, power_spectrum_dict, **kwargs): class Dens(DataVector): _kind = "density" _needed_keys = ["density", "density_error"] + _free_par = [] + _number_dimension_observation_covariance = 1 + _parameters_observation_covariance = ["density"] def give_data_and_variance(self, *args): """Return density data and diagonal variance from `density_error`. @@ -196,12 +219,21 @@ def give_data_and_variance(self, *args): Returns: tuple: (density, density_error^2). """ + + if self._covariance_observation is not None: + return self._data["density"], self._covariance_observation return self._data["density"], self._data["density_error"] ** 2 + def __init__(self, data, covariance_observation=None): + super().__init__(data, covariance_observation=covariance_observation) + class DirectVel(DataVector): _kind = "velocity" _needed_keys = ["velocity"] + _free_par = [] + _number_dimension_observation_covariance = 1 + _parameters_observation_covariance = ["velocity"] @property def conditional_needed_keys(self): @@ -253,66 +285,12 @@ def __init__(self, data, covariance_observation=None): self._covariance_observation = velocity_variance -class DensVel(DataVector): - _kind = "cross" - - @property - def needed_keys(self): - return self.densities.needed_keys + self.velocities.needed_keys - - @property - def free_par(self): - return self.densities.free_par + self.velocities.free_par - - def give_data_and_variance(self, *args): - data_density, density_variance = self.densities.give_data_and_variance(*args) - data_velocity, velocity_variance = self.velocities.give_data_and_variance(*args) - data = jnp.hstack((data_density, data_velocity)) - variance = jnp.hstack((density_variance, velocity_variance)) - return data, variance - - def __init__(self, density_vector, velocity_vector): - self.densities = density_vector - self.velocities = velocity_vector - - if self.velocities._covariance_observation is not None: - raise NotImplementedError( - "Velocity with observed covariance + density not implemented yet" - ) - - if jax_installed: - self.give_data_and_variance_jit = jit(self.give_data_and_variance) - - def compute_covariance(self, model, power_spectrum_dict, **kwargs): - - coords_dens = np.vstack( - ( - self.densities.data["ra"], - self.densities.data["dec"], - self.densities.data["rcom_zobs"], - ) - ) - - coords_vel = np.vstack( - ( - self.velocities.data["ra"], - self.velocities.data["dec"], - self.velocities.data["rcom_zobs"], - ) - ) - return CovMatrix.init_from_flip( - model, - "full", - power_spectrum_dict, - coordinates_density=coords_dens, - coordinates_velocity=coords_vel, - **kwargs, - ) - - class VelFromHDres(DataVector): + _kind = "velocity" _needed_keys = ["dmu", "zobs"] _free_par = ["M_0"] + _number_dimension_observation_covariance = 1 + _parameters_observation_covariance = ["dmu"] @property def conditional_needed_keys(self): @@ -331,25 +309,23 @@ def give_data_and_variance(self, parameter_values_dict): distance_modulus_difference_to_velocity * self._data["dmu"] - distance_modulus_difference_to_velocity * parameter_values_dict["M_0"] ) - if self._covariance_observation is None and "dmu_error" in self._data: - velocity_error = ( - distance_modulus_difference_to_velocity * self._data["dmu_error"] - ) - return velocity, velocity_error**2 - elif self._covariance_observation is not None: - J = jnp.diag(self._distance_modulus_difference_to_velocity) - velocity_variance = J @ self._covariance_observation @ J.T - return velocity, velocity_variance + if self._covariance_observation is None: + velocity_variance = ( + distance_modulus_difference_to_velocity * self._data["dmu_error"] + ) ** 2 else: - raise ValueError( - "Cannot compute velocity variance without dmu_error or covariance_observation" + conversion_matrix = jnp.diag(distance_modulus_difference_to_velocity) + + velocity_variance = ( + conversion_matrix @ self._covariance_observation @ conversion_matrix.T ) + return velocity, velocity_variance + def __init__( self, data, covariance_observation=None, velocity_estimator="full", **kwargs ): - # Compute conversion using provided input data, not uninitialized self._data self.velocity_estimator = velocity_estimator @@ -360,6 +336,8 @@ class VelFromIntrinsicScatter(DataVector): _kind = "velocity" _needed_keys = ["zobs"] _free_par = ["sigma_M"] + _number_dimension_observation_covariance = 0 + _parameters_observation_covariance = [] def give_data_and_variance(self, parameter_values_dict): distance_modulus_difference_to_velocity = ( @@ -389,3 +367,60 @@ def give_data_and_variance(self, parameter_values_dict): def __init__(self, data, velocity_estimator="full"): super().__init__(data) self.velocity_estimator = velocity_estimator + + +class DensVel(DataVector): + _kind = "cross" + + @property + def needed_keys(self): + return self.densities.needed_keys + self.velocities.needed_keys + + @property + def free_par(self): + return self.densities.free_par + self.velocities.free_par + + def give_data_and_variance(self, *args): + data_density, density_variance = self.densities.give_data_and_variance(*args) + data_velocity, velocity_variance = self.velocities.give_data_and_variance(*args) + data = jnp.hstack((data_density, data_velocity)) + variance = jnp.hstack((density_variance, velocity_variance)) + return data, variance + + def __init__(self, density_vector, velocity_vector): + self.densities = density_vector + self.velocities = velocity_vector + + if self.velocities._covariance_observation is not None: + raise NotImplementedError( + "Velocity with observed covariance + density not implemented yet" + ) + + if jax_installed: + self.give_data_and_variance_jit = jit(self.give_data_and_variance) + + def compute_covariance(self, model, power_spectrum_dict, **kwargs): + + coords_dens = np.vstack( + ( + self.densities.data["ra"], + self.densities.data["dec"], + self.densities.data["rcom_zobs"], + ) + ) + + coords_vel = np.vstack( + ( + self.velocities.data["ra"], + self.velocities.data["dec"], + self.velocities.data["rcom_zobs"], + ) + ) + return CovMatrix.init_from_flip( + model, + "full", + power_spectrum_dict, + coordinates_density=coords_dens, + coordinates_velocity=coords_vel, + **kwargs, + ) diff --git a/flip/data_vector/galaxypv_vectors.py b/flip/data_vector/galaxypv_vectors.py index ed7eee7..d0487d9 100644 --- a/flip/data_vector/galaxypv_vectors.py +++ b/flip/data_vector/galaxypv_vectors.py @@ -1,3 +1,5 @@ +from flip.utils import create_log + from .._config import __use_jax__ from . import vector_utils from .basic import DataVector @@ -19,10 +21,15 @@ jax_installed = False +log = create_log() + class VelFromLogDist(DataVector): _kind = "velocity" _needed_keys = ["eta"] + _free_par = [] + _number_dimension_observation_covariance = 1 + _parameters_observation_covariance = ["eta"] @property def conditional_needed_keys(self): @@ -36,6 +43,22 @@ def conditional_needed_keys(self): cond_keys += ["eta_error"] return self._needed_keys + cond_keys + def __init__( + self, + data, + covariance_observation=None, + velocity_estimator="full", + ): + """Initialize velocity from log-distance `eta`. + + Args: + data (dict): Must include `eta` and optionally `eta_error`. + covariance_observation (ndarray|None): Observed covariance. + velocity_estimator (str): Estimator name, default `"full"`. + """ + self.velocity_estimator = velocity_estimator + super().__init__(data, covariance_observation=covariance_observation) + def give_data_and_variance(self, parameter_values_dict, *args): """Return velocity and variance for log-distance based estimator. @@ -53,31 +76,19 @@ def give_data_and_variance(self, parameter_values_dict, *args): velocity = log_distance_to_velocity * self._data["eta"] - if self._covariance_observation is not None: - J = jnp.diag(log_distance_to_velocity) - velocity_variance = J @ self._covariance_observation @ J.T - return velocity, velocity_variance + if self._covariance_observation is None: + velocity_variance = ( + log_distance_to_velocity * self._data["eta_error"] + ) ** 2 - return velocity, (log_distance_to_velocity * self._data["eta_error"]) ** 2 + else: + conversion_matrix = jnp.diag(log_distance_to_velocity) - def __init__( - self, - data, - covariance_observation=None, - velocity_estimator="full", - ): - """Initialize velocity from log-distance `eta`. + velocity_variance = ( + conversion_matrix @ self._covariance_observation @ conversion_matrix.T + ) - Args: - data (dict): Must include `eta` and optionally `eta_error`. - covariance_observation (ndarray|None): Observed covariance. - velocity_estimator (str): Estimator name, default `"full"`. - """ - self.velocity_estimator = velocity_estimator - super().__init__( - data, - covariance_observation=covariance_observation, - ) + return velocity, velocity_variance class VelFromTullyFisher(DataVector): @@ -85,6 +96,8 @@ class VelFromTullyFisher(DataVector): _kind = "velocity" _needed_keys = ["zobs", "logW", "m_mean", "rcom_zobs"] _free_par = ["a", "b"] + _number_dimension_observation_covariance = 2 + _parameters_observation_covariance = ["logW", "m_mean"] @property def conditional_needed_keys(self): @@ -98,6 +111,39 @@ def conditional_needed_keys(self): cond_keys += ["e_logW", "e_m_mean"] return cond_keys + def __init__( + self, + data, + h, + covariance_observation=None, + velocity_estimator="full", + ): + """Initialize Tully–Fisher velocity vector. + + Args: + data (dict): Includes `logW`, `m_mean`, redshifts and distances. + h (float): Little-h scaling for distances. + covariance_observation (ndarray|None): Optional observation covariance. + velocity_estimator (str): Estimator name. + + Raises: + ValueError: If covariance shape is not `2N x 2N` when provided. + """ + super().__init__(data, covariance_observation=covariance_observation) + self.velocity_estimator = velocity_estimator + self.h = h + self._host_matrix = None + + if "host_group_id" in data: + self._host_matrix, self._data_to_group_mapping = ( + vector_utils.compute_host_matrix(self._data["host_group_id"]) + ) + self._data = vector_utils.format_data_multiple_host( + self._data, self._host_matrix + ) + if jax_installed: + self._host_matrix = BCOO.from_scipy_sparse(self._host_matrix) + def compute_observed_distance_modulus(self, parameter_values_dict): """Compute observed distance modulus from Tully–Fisher relation. @@ -159,11 +205,23 @@ def compute_observed_distance_modulus_variance( ) variance_distance_modulus += parameter_values_dict["sigma_M"] ** 2 else: + weights_observation_covariance = jnp.array( + [ + 1.0, + parameter_values_dict["a"], + ] + ) + jacobian = jnp.kron( + weights_observation_covariance, + jnp.eye(self._number_datapoints), + ) variance_distance_modulus = ( - self._covariance_observation - + jnp.eye(self._covariance_observation.shape[0]) - * parameter_values_dict["sigma_M"] ** 2 + jacobian @ self._covariance_observation @ jacobian.T + ) + variance_distance_modulus += ( + jnp.eye(self._number_datapoints) * parameter_values_dict["sigma_M"] ** 2 ) + return variance_distance_modulus def give_data_and_variance(self, parameter_values_dict): @@ -190,10 +248,13 @@ def give_data_and_variance(self, parameter_values_dict): * distance_modulus_difference_to_velocity**2 ) else: - A = self._init_A() - J = A[0] + parameter_values_dict["a"] * A[1] - J = jnp.diag(distance_modulus_difference_to_velocity) @ J - velocity_variance = J @ observed_distance_modulus_variance @ J.T + conversion_matrix = jnp.diag(distance_modulus_difference_to_velocity) + + velocity_variance = ( + conversion_matrix + @ observed_distance_modulus_variance + @ conversion_matrix.T + ) velocities = ( distance_modulus_difference_to_velocity @@ -207,18 +268,25 @@ def give_data_and_variance(self, parameter_values_dict): return velocities, velocity_variance - def _init_A(self): - """Initialize design matrices for linear propagation with covariance. + +class VelFromFundamentalPlane(DataVector): + _kind = "velocity" + _needed_keys = ["zobs", "logRe", "logsig", "logI", "rcom_zobs"] + _free_par = ["a", "b", "c"] + _number_dimension_observation_covariance = 3 + _parameters_observation_covariance = ["logRe", "logsig", "logI"] + + @property + def conditional_needed_keys(self): + """Conditionally required keys when covariance is absent. Returns: - ndarray: Matrix A blocks. + list[str]: Includes `e_logRe`, `e_logsig`, `e_logI` when needed. """ - N = len(self._data) - A = jnp.ones((2, N, 2 * N)) - ij = jnp.ogrid[:N, : 2 * N] - for k in range(2): - A[k][ij[1] == 2 * ij[0] + k] = 1 - return A + cond_keys = [] + if self._covariance_observation is None: + cond_keys += ["e_logRe", "e_logsig", "e_logI"] + return cond_keys def __init__( self, @@ -227,21 +295,20 @@ def __init__( covariance_observation=None, velocity_estimator="full", ): - """Initialize Tully–Fisher velocity vector. + """Initialize Fundamental Plane velocity vector. Args: - data (dict): Includes `logW`, `m_mean`, redshifts and distances. + data (dict): Includes `logRe`, `logsig`, `logI`, redshifts and distances. h (float): Little-h scaling for distances. covariance_observation (ndarray|None): Optional observation covariance. velocity_estimator (str): Estimator name. Raises: - ValueError: If covariance shape is not `2N x 2N` when provided. + ValueError: If covariance shape is not `3N x 3N` when provided. """ super().__init__(data, covariance_observation=covariance_observation) self.velocity_estimator = velocity_estimator self.h = h - self._A = None self._host_matrix = None if "host_group_id" in data: @@ -254,29 +321,6 @@ def __init__( if jax_installed: self._host_matrix = BCOO.from_scipy_sparse(self._host_matrix) - if self._covariance_observation is not None: - if self._covariance_observation.shape != (2 * len(data), 2 * len(data)): - raise ValueError("Cov should be 2N x 2N") - - -class VelFromFundamentalPlane(DataVector): - - _kind = "velocity" - _needed_keys = ["zobs", "logRe", "logsig", "logI", "rcom_zobs"] - _free_par = ["a", "b", "c"] - - @property - def conditional_needed_keys(self): - """Conditionally required keys when covariance is absent. - - Returns: - list[str]: Includes `e_logRe`, `e_logsig`, `e_logI` when needed. - """ - cond_keys = [] - if self._covariance_observation is None: - cond_keys += ["e_logRe", "e_logsig", "e_logI"] - return cond_keys - def compute_observed_distance_modulus(self, parameter_values_dict): """Compute observed distance modulus from Fundamental Plane relation. @@ -340,11 +384,24 @@ def compute_observed_distance_modulus_variance( ) variance_distance_modulus += parameter_values_dict["sigma_M"] ** 2 else: + weights_observation_covariance = jnp.array( + [ + 1.0, + parameter_values_dict["a"], + parameter_values_dict["b"], + ] + ) + jacobian = jnp.kron( + weights_observation_covariance, + jnp.eye(self._number_datapoints), + ) variance_distance_modulus = ( - self._covariance_observation - + jnp.eye(self._covariance_observation.shape[0]) - * parameter_values_dict["sigma_M"] ** 2 + jacobian @ self._covariance_observation @ jacobian.T ) + variance_distance_modulus += ( + jnp.eye(self._number_datapoints) * parameter_values_dict["sigma_M"] ** 2 + ) + return variance_distance_modulus def give_data_and_variance(self, parameter_values_dict): @@ -371,14 +428,13 @@ def give_data_and_variance(self, parameter_values_dict): * distance_modulus_difference_to_velocity**2 ) else: - A = self._init_A() - J = ( - A[0] - + parameter_values_dict["a"] * A[1] - + parameter_values_dict["b"] * A[2] + conversion_matrix = jnp.diag(distance_modulus_difference_to_velocity) + + velocity_variance = ( + conversion_matrix + @ observed_distance_modulus_variance + @ conversion_matrix.T ) - J = jnp.diag(distance_modulus_difference_to_velocity) @ J - velocity_variance = J @ observed_distance_modulus_variance @ J.T velocities = ( distance_modulus_difference_to_velocity @@ -391,54 +447,3 @@ def give_data_and_variance(self, parameter_values_dict): ) return velocities, velocity_variance - - def _init_A(self): - """Initialize design matrices for linear propagation with covariance. - - Returns: - ndarray: Matrix A blocks. - """ - N = len(self._data) - A = jnp.ones((3, N, 3 * N)) - ij = jnp.ogrid[:N, : 3 * N] - for k in range(3): - A[k][ij[1] == 3 * ij[0] + k] = 1 - return A - - def __init__( - self, - data, - h, - covariance_observation=None, - velocity_estimator="full", - ): - """Initialize Fundamental Plane velocity vector. - - Args: - data (dict): Includes `logRe`, `logsig`, `logI`, redshifts and distances. - h (float): Little-h scaling for distances. - covariance_observation (ndarray|None): Optional observation covariance. - velocity_estimator (str): Estimator name. - - Raises: - ValueError: If covariance shape is not `3N x 3N` when provided. - """ - super().__init__(data, covariance_observation=covariance_observation) - self.velocity_estimator = velocity_estimator - self.h = h - self._A = None - self._host_matrix = None - - if "host_group_id" in data: - self._host_matrix, self._data_to_group_mapping = ( - vector_utils.compute_host_matrix(self._data["host_group_id"]) - ) - self._data = vector_utils.format_data_multiple_host( - self._data, self._host_matrix - ) - if jax_installed: - self._host_matrix = BCOO.from_scipy_sparse(self._host_matrix) - - if self._covariance_observation is not None: - if self._covariance_observation.shape != (3 * len(data), 3 * len(data)): - raise ValueError("Cov should be 3N x 3N") diff --git a/flip/data_vector/snia_vectors.py b/flip/data_vector/snia_vectors.py index f428c05..ad4df21 100644 --- a/flip/data_vector/snia_vectors.py +++ b/flip/data_vector/snia_vectors.py @@ -1,3 +1,5 @@ +from flip.utils import create_log + from .._config import __use_jax__ from . import vector_utils from .basic import DataVector @@ -19,11 +21,15 @@ jax_installed = False +log = create_log() + class VelFromSALTfit(DataVector): _kind = "velocity" _needed_keys = ["zobs", "mb", "x1", "c", "rcom_zobs"] _free_par = ["alpha", "beta", "M_0", "sigma_M"] + _number_dimension_observation_covariance = 3 + _parameters_observation_covariance = ["mb", "x1", "c"] @property def conditional_needed_keys(self): @@ -49,6 +55,42 @@ def conditional_free_par(self): _cond_fpar += ["gamma"] return _cond_fpar + def __init__( + self, + data, + h, + covariance_observation=None, + velocity_estimator="full", + mass_step=10, + ): + """Initialize SN Ia velocity vector from SALT2 fits. + + Args: + data (dict): Includes SALT2 parameters and cosmology fields. + h (float): Little-h scaling for distances. + covariance_observation (ndarray|None): Optional observation covariance. + velocity_estimator (str): Estimator name. + mass_step (float): Threshold for host mass step correction. + + Raises: + ValueError: If covariance shape is not adapted + """ + super().__init__(data, covariance_observation=covariance_observation) + self.velocity_estimator = velocity_estimator + self.h = h + self._host_matrix = None + self._mass_step = mass_step + + if "host_group_id" in data: + self._host_matrix, self._data_to_group_mapping = ( + vector_utils.compute_host_matrix(self._data["host_group_id"]) + ) + self._data = vector_utils.format_data_multiple_host( + self._data, self._host_matrix + ) + if jax_installed: + self._host_matrix = BCOO.from_scipy_sparse(self._host_matrix) + def compute_observed_distance_modulus(self, parameter_values_dict): """Compute observed distance modulus from SALT2 fit parameters. @@ -122,11 +164,24 @@ def compute_observed_distance_modulus_variance(self, parameter_values_dict): ) variance_distance_modulus += parameter_values_dict["sigma_M"] ** 2 else: + weights_observation_covariance = jnp.array( + [ + 1.0, + parameter_values_dict["alpha"], + -parameter_values_dict["beta"], + ] + ) + jacobian = jnp.kron( + weights_observation_covariance, + jnp.eye(self._number_datapoints), + ) variance_distance_modulus = ( - self._covariance_observation - + jnp.eye(self._covariance_observation.shape[0]) - * parameter_values_dict["sigma_M"] ** 2 + jacobian @ self._covariance_observation @ jacobian.T ) + variance_distance_modulus += ( + jnp.eye(self._number_datapoints) * parameter_values_dict["sigma_M"] ** 2 + ) + return variance_distance_modulus def give_data_and_variance(self, parameter_values_dict): @@ -152,14 +207,13 @@ def give_data_and_variance(self, parameter_values_dict): * distance_modulus_difference_to_velocity**2 ) else: - A = self._init_A() - J = ( - A[0] - + parameter_values_dict["alpha"] * A[1] - - parameter_values_dict["beta"] * A[2] + conversion_matrix = jnp.diag(distance_modulus_difference_to_velocity) + + velocity_variance = ( + conversion_matrix + @ observed_distance_modulus_variance + @ conversion_matrix.T ) - J = jnp.diag(distance_modulus_difference_to_velocity) @ J - velocity_variance = J @ observed_distance_modulus_variance @ J.T velocities = ( distance_modulus_difference_to_velocity @@ -172,58 +226,3 @@ def give_data_and_variance(self, parameter_values_dict): ) return velocities, velocity_variance - - def _init_A(self): - """Initialize design matrices for linear covariance propagation. - - Returns: - ndarray: Matrix A blocks. - """ - N = len(self._data) - A = jnp.ones((3, N, 3 * N)) - ij = jnp.ogrid[:N, : 3 * N] - for k in range(3): - A[k][ij[1] == 3 * ij[0] + k] = 1 - return A - - def __init__( - self, - data, - h, - covariance_observation=None, - velocity_estimator="full", - mass_step=10, - ): - """Initialize SN Ia velocity vector from SALT2 fits. - - Args: - data (dict): Includes SALT2 parameters and cosmology fields. - h (float): Little-h scaling for distances. - covariance_observation (ndarray|None): Optional observation covariance. - velocity_estimator (str): Estimator name. - mass_step (float): Threshold for host mass step correction. - - Raises: - ValueError: If covariance shape is not `3N x 3N` when provided. - """ - super().__init__(data, covariance_observation=covariance_observation) - self.velocity_estimator = velocity_estimator - self.h = h - self._A = None - self._host_matrix = None - self._mass_step = mass_step - - if "host_group_id" in data: - self._host_matrix, self._data_to_group_mapping = ( - vector_utils.compute_host_matrix(self._data["host_group_id"]) - ) - self._data = vector_utils.format_data_multiple_host( - self._data, self._host_matrix - ) - if jax_installed: - self._host_matrix = BCOO.from_scipy_sparse(self._host_matrix) - - if self._covariance_observation is not None: - if self._covariance_observation.shape != (3 * len(data), 3 * len(data)): - raise ValueError("Cov should be 3N x 3N") - self._A = self._init_A()