Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/bumpversion.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
run: bump2version patch --verbose

- name: Bump major version
if: ${{ ontains(github.event.head_commit.message, '[bump major]') }}
if: ${{contains(github.event.head_commit.message, '[bump major]') }}
run: bump2version major --verbose --tag

- name: Bump minor version
Expand Down
195 changes: 115 additions & 80 deletions flip/data_vector/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ class DataVector(abc.ABC):
_kind (str): One of "velocity", "density" or "cross".
"""

_free_par = []
_kind = "" # 'velocity', 'density' or 'cross'
_needed_keys = []
_free_par = []
_number_dimension_observation_covariance = 0
_parameters_observation_covariance = []

@property
def conditional_free_par(self):
Expand Down Expand Up @@ -109,16 +112,6 @@ def give_data_and_variance(self, **kwargs):
"""
pass

def _check_keys(self, data):
"""Validate that `data` contains all required keys.

Raises:
ValueError: When a required key is missing.
"""
for k in self.needed_keys:
if k not in data:
raise ValueError(f"{k} field is needed in data")

def __init__(self, data, covariance_observation=None, **kwargs):
"""Initialize data vector with data and optional observation covariance.

Expand All @@ -129,6 +122,8 @@ def __init__(self, data, covariance_observation=None, **kwargs):
"""
self._covariance_observation = covariance_observation
self._check_keys(data)
self._number_datapoints = len(data[self.needed_keys[0]])
self.check_covariance_observation()
self._data = copy.copy(data)
self._kwargs = kwargs

Expand All @@ -138,6 +133,31 @@ def __init__(self, data, covariance_observation=None, **kwargs):
if jax_installed:
self.give_data_and_variance_jit = jit(self.give_data_and_variance)

def check_covariance_observation(self):
if self._covariance_observation is not None:
if self._covariance_observation.shape != (
self._number_dimension_observation_covariance * self._number_datapoints,
self._number_dimension_observation_covariance * self._number_datapoints,
):
raise ValueError(
f"Observation covariance matrix should be {self._number_dimension_observation_covariance}N "
f"x {self._number_dimension_observation_covariance}N"
)
log.add(
f"Loading observation covariance matrix, "
f"expecting {self._parameters_observation_covariance} parameters."
)

def _check_keys(self, data):
"""Validate that `data` contains all required keys.

Raises:
ValueError: When a required key is missing.
"""
for k in self.needed_keys:
if k not in data:
raise ValueError(f"{k} field is needed in data")

def get_masked_data_and_cov(self, bool_mask):
"""Return masked data and corresponding masked observation covariance.

Expand Down Expand Up @@ -189,19 +209,31 @@ def compute_covariance(self, model, power_spectrum_dict, **kwargs):
class Dens(DataVector):
_kind = "density"
_needed_keys = ["density", "density_error"]
_free_par = []
_number_dimension_observation_covariance = 1
_parameters_observation_covariance = ["density"]

def give_data_and_variance(self, *args):
"""Return density data and diagonal variance from `density_error`.

Returns:
tuple: (density, density_error^2).
"""

if self._covariance_observation is not None:
return self._data["density"], self._covariance_observation
return self._data["density"], self._data["density_error"] ** 2

def __init__(self, data, covariance_observation=None):
super().__init__(data, covariance_observation=covariance_observation)


class DirectVel(DataVector):
_kind = "velocity"
_needed_keys = ["velocity"]
_free_par = []
_number_dimension_observation_covariance = 1
_parameters_observation_covariance = ["velocity"]

@property
def conditional_needed_keys(self):
Expand Down Expand Up @@ -253,66 +285,12 @@ def __init__(self, data, covariance_observation=None):
self._covariance_observation = velocity_variance


class DensVel(DataVector):
_kind = "cross"

@property
def needed_keys(self):
return self.densities.needed_keys + self.velocities.needed_keys

@property
def free_par(self):
return self.densities.free_par + self.velocities.free_par

def give_data_and_variance(self, *args):
data_density, density_variance = self.densities.give_data_and_variance(*args)
data_velocity, velocity_variance = self.velocities.give_data_and_variance(*args)
data = jnp.hstack((data_density, data_velocity))
variance = jnp.hstack((density_variance, velocity_variance))
return data, variance

def __init__(self, density_vector, velocity_vector):
self.densities = density_vector
self.velocities = velocity_vector

if self.velocities._covariance_observation is not None:
raise NotImplementedError(
"Velocity with observed covariance + density not implemented yet"
)

if jax_installed:
self.give_data_and_variance_jit = jit(self.give_data_and_variance)

def compute_covariance(self, model, power_spectrum_dict, **kwargs):

coords_dens = np.vstack(
(
self.densities.data["ra"],
self.densities.data["dec"],
self.densities.data["rcom_zobs"],
)
)

coords_vel = np.vstack(
(
self.velocities.data["ra"],
self.velocities.data["dec"],
self.velocities.data["rcom_zobs"],
)
)
return CovMatrix.init_from_flip(
model,
"full",
power_spectrum_dict,
coordinates_density=coords_dens,
coordinates_velocity=coords_vel,
**kwargs,
)


class VelFromHDres(DataVector):
_kind = "velocity"
_needed_keys = ["dmu", "zobs"]
_free_par = ["M_0"]
_number_dimension_observation_covariance = 1
_parameters_observation_covariance = ["dmu"]

@property
def conditional_needed_keys(self):
Expand All @@ -331,25 +309,23 @@ def give_data_and_variance(self, parameter_values_dict):
distance_modulus_difference_to_velocity * self._data["dmu"]
- distance_modulus_difference_to_velocity * parameter_values_dict["M_0"]
)
if self._covariance_observation is None and "dmu_error" in self._data:
velocity_error = (
distance_modulus_difference_to_velocity * self._data["dmu_error"]
)
return velocity, velocity_error**2

elif self._covariance_observation is not None:
J = jnp.diag(self._distance_modulus_difference_to_velocity)
velocity_variance = J @ self._covariance_observation @ J.T
return velocity, velocity_variance
if self._covariance_observation is None:
velocity_variance = (
distance_modulus_difference_to_velocity * self._data["dmu_error"]
) ** 2
else:
raise ValueError(
"Cannot compute velocity variance without dmu_error or covariance_observation"
conversion_matrix = jnp.diag(distance_modulus_difference_to_velocity)

velocity_variance = (
conversion_matrix @ self._covariance_observation @ conversion_matrix.T
)

return velocity, velocity_variance

def __init__(
self, data, covariance_observation=None, velocity_estimator="full", **kwargs
):
# Compute conversion using provided input data, not uninitialized self._data

self.velocity_estimator = velocity_estimator

Expand All @@ -360,6 +336,8 @@ class VelFromIntrinsicScatter(DataVector):
_kind = "velocity"
_needed_keys = ["zobs"]
_free_par = ["sigma_M"]
_number_dimension_observation_covariance = 0
_parameters_observation_covariance = []

def give_data_and_variance(self, parameter_values_dict):
distance_modulus_difference_to_velocity = (
Expand Down Expand Up @@ -389,3 +367,60 @@ def give_data_and_variance(self, parameter_values_dict):
def __init__(self, data, velocity_estimator="full"):
super().__init__(data)
self.velocity_estimator = velocity_estimator


class DensVel(DataVector):
_kind = "cross"

@property
def needed_keys(self):
return self.densities.needed_keys + self.velocities.needed_keys

@property
def free_par(self):
return self.densities.free_par + self.velocities.free_par

def give_data_and_variance(self, *args):
data_density, density_variance = self.densities.give_data_and_variance(*args)
data_velocity, velocity_variance = self.velocities.give_data_and_variance(*args)
data = jnp.hstack((data_density, data_velocity))
variance = jnp.hstack((density_variance, velocity_variance))
return data, variance

def __init__(self, density_vector, velocity_vector):
self.densities = density_vector
self.velocities = velocity_vector

if self.velocities._covariance_observation is not None:
raise NotImplementedError(
"Velocity with observed covariance + density not implemented yet"
)

if jax_installed:
self.give_data_and_variance_jit = jit(self.give_data_and_variance)

def compute_covariance(self, model, power_spectrum_dict, **kwargs):

coords_dens = np.vstack(
(
self.densities.data["ra"],
self.densities.data["dec"],
self.densities.data["rcom_zobs"],
)
)

coords_vel = np.vstack(
(
self.velocities.data["ra"],
self.velocities.data["dec"],
self.velocities.data["rcom_zobs"],
)
)
return CovMatrix.init_from_flip(
model,
"full",
power_spectrum_dict,
coordinates_density=coords_dens,
coordinates_velocity=coords_vel,
**kwargs,
)
Loading