Skip to content

Commit

Permalink
Overhaul coordinate handling
Browse files Browse the repository at this point in the history
  • Loading branch information
andreicuceu committed Nov 18, 2023
1 parent d492a9b commit 7c2bcc4
Show file tree
Hide file tree
Showing 10 changed files with 513 additions and 352 deletions.
10 changes: 6 additions & 4 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@ def test_data():

assert np.allclose(data.data_vec, hdul[1].data['DA'])

rp_rt_grid = corr_item.rp_rt_grid
assert np.allclose(rp_rt_grid[0], hdul[1].data['RP'])
assert np.allclose(rp_rt_grid[1], hdul[1].data['RT'])
assert np.allclose(corr_item.z_grid, hdul[1].data['Z'])
rp_grid = corr_item.model_coordinates.rp_grid
rt_grid = corr_item.model_coordinates.rt_grid
z_grid = corr_item.model_coordinates.z_grid
assert np.allclose(rp_grid, hdul[1].data['RP'])
assert np.allclose(rt_grid, hdul[1].data['RT'])
assert np.allclose(z_grid, hdul[1].data['Z'])

hdul.close()

Expand Down
141 changes: 141 additions & 0 deletions vega/coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import numpy as np


class Coordinates:
"""Class to handle Vega coordinate grids
"""

def __init__(self, rp_min, rp_max, rt_max, rp_nbins, rt_nbins,
rp_grid=None, rt_grid=None, z_grid=None, z_eff=None):
"""Initialize the coordinate grids.
Parameters
----------
rp_min : float
Minimum rp
rp_max : float
Maximum rp
rt_max : float
Maximum rt
rp_nbins : float
Number of rp bins
rt_nbins : float
Number of rt bins
rp_grid : Array , optional
rp grid, by default None
rt_grid : Array, optional
rt grid, by default None
z_grid : Array, optional
z grid, by default None
"""
self.rp_min = rp_min
self.rp_max = rp_max
self.rt_max = rt_max
self.rp_nbins = rp_nbins
self.rt_nbins = rt_nbins

self.rp_binsize = (rp_max - rp_min) / rp_nbins
self.rt_binsize = rt_max / rt_nbins

rp_regular_grid = np.arange(rp_min + self.rp_binsize / 2, rp_max, self.rp_binsize)
rt_regular_grid = np.arange(self.rt_binsize / 2, rt_max, self.rt_binsize)

rt_regular_grid, rp_regular_grid = np.meshgrid(rt_regular_grid, rp_regular_grid)
self.rp_regular_grid = rp_regular_grid.flatten()
self.rt_regular_grid = rt_regular_grid.flatten()

self.rp_grid = self.rp_regular_grid if rp_grid is None else rp_grid
self.rt_grid = self.rt_regular_grid if rt_grid is None else rt_grid

self.r_grid = np.sqrt(self.rp_grid**2 + self.rt_grid**2)
self.r_regular_grid = np.sqrt(self.rp_regular_grid**2 + self.rt_regular_grid**2)

self.mu_grid = np.zeros_like(self.r_grid)
w = self.r_grid > 0.
self.mu_grid[w] = self.rp_grid[w] / self.r_grid[w]

self.mu_regular_grid = np.zeros_like(self.r_regular_grid)
w = self.r_regular_grid > 0.
self.mu_regular_grid[w] = self.rp_regular_grid[w] / self.r_regular_grid[w]

if z_grid is None and z_eff is None:
self.z_grid = None
else:
self.z_grid = z_eff if z_grid is None else z_grid

@classmethod
def init_from_grids(cls, other, rp_grid, rt_grid, z_grid):
"""Initialize from other coordinates and new grids
Parameters
----------
other : Coordinates
Other coordinates
rp_grid : Array
rp grid
rt_grid : Array
rt grid
z_grid : Array
z grid
Returns
-------
Coordinates
New coordinates
"""
return cls(
other.rp_min, other.rp_max, other.rt_max, other.rp_nbins, other.rt_nbins,
rp_grid=rp_grid, rt_grid=rt_grid, z_grid=z_grid
)

def get_mask_to_other(self, other):
"""Build mask from the current coordinates to the other coordinates.
Parameters
----------
other : Coordinates
Other coordinates
Returns
-------
Array
Mask
"""
assert self.rp_binsize == other.rp_binsize
assert self.rt_binsize == other.rt_binsize
mask = (self.rp_grid >= other.rp_min) & (self.rp_grid <= other.rp_max)
mask &= (self.rt_grid <= other.rt_max)
return mask

def get_mask_scale_cuts(self, cuts_config):
"""Build mask to apply scale cuts
Parameters
----------
cuts_config : ConfigParser
Cuts section from config
Returns
-------
Array
Mask
"""
# Read the cuts
rp_min_cut = cuts_config.getfloat('rp-min', 0.)
rp_max_cut = cuts_config.getfloat('rp-max', 200.)

rt_min_cut = cuts_config.getfloat('rt-min', 0.)
rt_max_cut = cuts_config.getfloat('rt-max', 200.)

r_min_cut = cuts_config.getfloat('r-min', 10.)
r_max_cut = cuts_config.getfloat('r-max', 180.)

mu_min_cut = cuts_config.getfloat('mu-min', -1.)
mu_max_cut = cuts_config.getfloat('mu-max', +1.)

mask = (self.rp_regular_grid > rp_min_cut) & (self.rp_regular_grid < rp_max_cut)
mask &= (self.rt_regular_grid > rt_min_cut) & (self.rt_regular_grid < rt_max_cut)
mask &= (self.r_regular_grid > r_min_cut) & (self.r_regular_grid < r_max_cut)
mask &= (self.mu_regular_grid > mu_min_cut) & (self.mu_regular_grid < mu_max_cut)

return mask
13 changes: 6 additions & 7 deletions vega/correlation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from scipy.integrate import quad
from scipy.interpolate import interp1d
from astropy.table import Table
from pkg_resources import resource_exists, resource_filename

from . import utils

Expand All @@ -17,7 +16,7 @@ class CorrelationFunction:
Extensions should have their separate method of the form
'compute_extension' that can be called from outside
"""
def __init__(self, config, fiducial, coords_grid, scale_params,
def __init__(self, config, fiducial, coordinates, scale_params,
tracer1, tracer2, bb_config=None, metal_corr=False):
"""
Expand All @@ -27,8 +26,8 @@ def __init__(self, config, fiducial, coords_grid, scale_params,
model section of config file
fiducial : dict
fiducial config
coords_grid : dict
Dictionary with coordinate grid - r, mu, z
coordinates : Coordinates
Vega coordinates object
scale_params : ScaleParameters
ScaleParameters object
tracer1 : dict
Expand All @@ -41,9 +40,9 @@ def __init__(self, config, fiducial, coords_grid, scale_params,
Whether this is a metal correlation, by default False
"""
self._config = config
self._r = coords_grid['r']
self._mu = coords_grid['mu']
self._z = coords_grid['z']
self._r = coordinates.r_grid
self._mu = coordinates.mu_grid
self._z = coordinates.z_grid
self._multipole = config.getint('single_multipole', -1)
self._tracer1 = tracer1
self._tracer2 = tracer2
Expand Down
Loading

0 comments on commit 7c2bcc4

Please sign in to comment.