Skip to content

Commit

Permalink
* fix strange refactor error
Browse files Browse the repository at this point in the history
  • Loading branch information
Joshuaalbert committed Sep 27, 2024
1 parent e3ff199 commit ebdb70f
Show file tree
Hide file tree
Showing 45 changed files with 176 additions and 137 deletions.
3 changes: 1 addition & 2 deletions debug/psf_optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from jaxns import Prior
from tomographic_kernel.frames import ENU

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.assets.content_registry import fill_registries
from dsa2000_cal.assets.registries import array_registry
from dsa2000_cal.common.quantity_utils import quantity_to_jnp
Expand Down Expand Up @@ -129,7 +128,7 @@ def compute_dynamic_range(antenna_locations: jax.Array, lm: jax.Array, freq: jax
antennas_itrs = antennas.get_itrs(obstime=obstime, location=location)
antennas_enu = antennas_itrs.transform_to(ENU(obstime=obstime, location=location))

antennas_enu_xyz = dsa2000_cal.common.mixed_precision_utils.T # [n, 3]
antennas_enu_xyz = antennas_enu.cartesian.xyz.T # [n, 3]
min_east = jnp.min(antennas_enu_xyz[:, 0])
max_east = jnp.max(antennas_enu_xyz[:, 0])
d_east = (max_east - min_east) / 10
Expand Down
10 changes: 10 additions & 0 deletions dsa2000_cal/dsa2000_cal/assets/arrays/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,13 @@ def get_antenna_model(self) -> AbstractAntennaModel:
antenna beam
"""
...

@abstractmethod
def integration_time(self) -> au.Quantity:
"""
Get integration time (s)
Returns:
integration time
"""
...
5 changes: 4 additions & 1 deletion dsa2000_cal/dsa2000_cal/assets/arrays/dsa2000W/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ class DSA2000WArray(AbstractArray):
DSA2000W array class.
"""

def integration_time(self) -> au.Quantity:
return 1.5 * au.s

def get_channel_width(self) -> au.Quantity:
return (2000e6 * au.Hz - 700e6 * au.Hz) / 8000
return 1300 * au.MHz / 10000

def get_array_location(self) -> ac.EarthLocation:
return mean_itrs(self.get_antennas().get_itrs()).earth_location
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from astropy import coordinates as ac

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.abc import AbstractAntennaModel
from dsa2000_cal.antenna_model.antenna_beam import AltAzAntennaModel
from dsa2000_cal.assets.arrays.dsa2000W.array import DSA2000WArray
Expand Down Expand Up @@ -102,7 +101,7 @@ def _get_antennas(self) -> ac.EarthLocation:
all_antennas = array.get_antennas()
array_centre = array.get_array_location()
all_antennas_itrs = all_antennas.get_itrs()
all_antennas_itrs_xyz = dsa2000_cal.common.mixed_precision_utils.T
all_antennas_itrs_xyz = all_antennas_itrs.T
max_baseline = np.max(
np.linalg.norm(
all_antennas_itrs_xyz[:, None, :] - all_antennas_itrs_xyz[None, :, :],
Expand Down
3 changes: 3 additions & 0 deletions dsa2000_cal/dsa2000_cal/assets/arrays/lwa/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class LWAArray(AbstractArray):
LWA array class.
"""

def integration_time(self) -> au.Quantity:
return 10. * au.s

def get_channel_width(self) -> au.Quantity:
return 23913.3199056 * au.Hz

Expand Down
3 changes: 1 addition & 2 deletions dsa2000_cal/dsa2000_cal/assets/arrays/lwa_mock/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
from astropy import coordinates as ac

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.abc import AbstractAntennaModel
from dsa2000_cal.antenna_model.antenna_beam import AltAzAntennaModel
from dsa2000_cal.assets.arrays.dsa2000W.array import DSA2000WArray
Expand Down Expand Up @@ -103,7 +102,7 @@ def _get_antennas(self) -> ac.EarthLocation:
all_antennas = array.get_antennas()
array_centre = array.get_array_location()
all_antennas_itrs = all_antennas.get_itrs()
all_antennas_itrs_xyz = dsa2000_cal.common.mixed_precision_utils.T
all_antennas_itrs_xyz = all_antennas_itrs.cartesian.xyz.T
max_baseline = np.max(
np.linalg.norm(
all_antennas_itrs_xyz[:, None, :] - all_antennas_itrs_xyz[None, :, :],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from jaxopt._src.linear_solve import solve_qr
from jaxopt._src.tree_util import tree_l2_norm, tree_inf_norm

import dsa2000_cal.common.mixed_precision_utils


class LevenbergMarquardtState(NamedTuple):
Expand Down Expand Up @@ -199,7 +198,7 @@ def init_state(self, init_params: Any, *args,

if self.materialize_jac:
jac = self._jac_fun(init_params, *args, **kwargs)
jt = dsa2000_cal.common.mixed_precision_utils.T
jt = jac.T
jtj = jt @ jac
gradient = jt @ residual
damping_factor = self.damping_parameter * jnp.max(jnp.diag(jtj))
Expand Down Expand Up @@ -256,7 +255,7 @@ def gain_ratio_test_true_func(params, damping_factor,
if self.materialize_jac:
# Calculate Jacobian and it's transpose based on the updated coeffs.
jac = self._jac_fun(params, *args, **kwargs)
jt = dsa2000_cal.common.mixed_precision_utils.T
jt = jac.T
# J^T.J is the gauss newton approximate hessian.
jtj = jt @ jac
gradient = jt @ residual
Expand Down Expand Up @@ -326,7 +325,7 @@ def update_state_using_delta_params(self, loss_curr, params, delta_params,

# Calculate denominator of the gain ratio based on Eq. 6.16, "Introduction
# to optimization and data fitting", L(0)-L(hlm)=0.5*hlm^T*(mu*hlm-g).
gain_ratio_denom = 0.5 * dsa2000_cal.common.mixed_precision_utils.T @ (
gain_ratio_denom = 0.5 * delta_params.T @ (
damping_factor * delta_params - gradient)

# Current value of loss function F=0.5*||f||^2.
Expand Down Expand Up @@ -529,8 +528,8 @@ def _jtj_op(self, params, vec, *args, **kwargs):

def _jtj_diag_op(self, params, *args, **kwargs):
"""Diagonal elements of J^T.J, where J is jacobian of fun at params."""
diag_op = lambda v: dsa2000_cal.common.mixed_precision_utils.T @ self._jtj_op(params, v, *args, **kwargs)
return dsa2000_cal.common.mixed_precision_utils.T
diag_op = lambda v: v.T @ self._jtj_op(params, v, *args, **kwargs)
return jax.vmap(diag_op)(jnp.eye(len(params))).T

def _d2fvv_op(self, primals, tangents1, tangents2, *args, **kwargs):
"""Product with d2f.v1v2."""
Expand Down
11 changes: 5 additions & 6 deletions dsa2000_cal/dsa2000_cal/common/astropy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from astropy.coordinates import Angle
from astropy.coordinates.angles import offset_by

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.coord_utils import lmn_to_icrs


Expand Down Expand Up @@ -156,7 +155,7 @@ def mean_icrs(coords: ac.ICRS) -> ac.ICRS:
Returns:
the mean ITRS coordinate
"""
mean_coord = dsa2000_cal.common.mixed_precision_utils.T.mean(axis=0)
mean_coord = coords.cartesian.xyz.T.mean(axis=0)
mean_coord /= np.linalg.norm(mean_coord)
spherical = ac.ICRS(mean_coord, representation_type='cartesian').spherical
return ac.ICRS(ra=spherical.lon, dec=spherical.lat)
Expand Down Expand Up @@ -256,13 +255,13 @@ def fibonacci_celestial_sphere(n: int) -> ac.ICRS:

return ac.ICRS(lon * au.rad, lat * au.rad)


@pytest.mark.parametrize('n', [10, 100, 1000])
def test_fibonacci_celestial_sphere(n:int):
def test_fibonacci_celestial_sphere(n: int):
pointings = fibonacci_celestial_sphere(n=n)
import pylab as plt
plt.scatter(pointings.ra, pointings.dec, s=1)
plt.show()

mean_area = (4*np.pi / n) * au.rad**2
print(n, mean_area.to('deg^2'))

mean_area = (4 * np.pi / n) * au.rad ** 2
print(n, mean_area.to('deg^2'))
3 changes: 1 addition & 2 deletions dsa2000_cal/dsa2000_cal/common/coord_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from astropy.units import Quantity
from tomographic_kernel.frames import ENU

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.quantity_utils import quantity_to_jnp
from dsa2000_cal.delay_models.uvw_utils import perley_icrs_from_lmn, perley_lmn_from_icrs

Expand Down Expand Up @@ -82,7 +81,7 @@ def earth_location_to_uvw_approx(antennas: EarthLocation, obs_time: at.Time, pha
antennas_uvw = antennas_gcrs.transform_to(frame_uvw)

w, u, v = antennas_uvw.cartesian.xyz
uvw = dsa2000_cal.common.mixed_precision_utils.T
uvw = ac.CartesianRepresentation(u, v, w).xyz.T
uvw = uvw.reshape(shape + (3,))
return uvw

Expand Down
3 changes: 1 addition & 2 deletions dsa2000_cal/dsa2000_cal/common/fits_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from astropy.wcs import WCS
from scipy.ndimage import zoom

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.quantity_utils import quantity_to_np
from dsa2000_cal.common.serialise_utils import SerialisableBaseModel

Expand Down Expand Up @@ -130,7 +129,7 @@ def transform_to_wsclean_model(fits_file: str, output_file: str, pointing_centre
# Apply perm. Note: because python is column-major we need to reverse the perm
# print(perm)

data = np.transpose(dsa2000_cal.common.mixed_precision_utils.T, perm).T.copy() # [Ns, Nf, Ndec, Nra]
data = np.transpose(hdu[0].data.T, perm).T.copy() # [Ns, Nf, Ndec, Nra]
# print(data.shape)
Ns, Nf, Ndec, Nra = data.shape
new_wcs.wcs.crval = [pointing_centre.ra.deg, pointing_centre.dec.deg, ref_freq_hz, 1]
Expand Down
7 changes: 3 additions & 4 deletions dsa2000_cal/dsa2000_cal/common/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
import matplotlib.pyplot as plt
import numpy as np

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.types import SystemGains
from dsa2000_cal.calibration.probabilistic_models.gains_per_facet_model import CalibrationSolutions
from dsa2000_cal.types import SystemGains


def figs_to_gif(fig_generator, gif_path, duration=0.5, loop=0, dpi=80):
Expand Down Expand Up @@ -87,7 +86,7 @@ def plot_antenna_gains(gain_obj: SystemGains | CalibrationSolutions, antenna_idx
# 0,0 -> 0, 1,0 -> 1, 0,1 -> 2, 1,1 -> 3
row = 2 * q + p
axs[row][0].imshow(
dsa2000_cal.common.mixed_precision_utils.T,
amplitude[:, :, p, q].T,
aspect='auto',
origin='lower',
cmap='viridis',
Expand All @@ -100,7 +99,7 @@ def plot_antenna_gains(gain_obj: SystemGains | CalibrationSolutions, antenna_idx
axs[row, 0].set_ylabel("Frequency (Hz)")

axs[row][1].imshow(
dsa2000_cal.common.mixed_precision_utils.T,
phase[:, :, p, q].T,
aspect='auto',
origin='lower',
cmap='hsv',
Expand Down
6 changes: 3 additions & 3 deletions dsa2000_cal/dsa2000_cal/common/tests/test_astropy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from astropy import coordinates as ac, units as au
from matplotlib import pyplot as plt

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.astropy_utils import random_discrete_skymodel, mean_icrs, \
create_spherical_grid, create_spherical_earth_grid, create_random_spherical_layout

Expand Down Expand Up @@ -61,14 +60,15 @@ def test_create_spherical_earth_grid():
plt.scatter(grid_earth.geodetic[0], grid_earth.geodetic[1], marker='o')
# plt.scatter(center.geodetic[0], center.geodetic[1], marker='x')
plt.show()
assert np.linalg.norm(dsa2000_cal.common.mixed_precision_utils.T - center.get_itrs().cartesian.xyz,
assert np.linalg.norm(grid_earth.get_itrs().cartesian.xyz.T - center.get_itrs().cartesian.xyz,
axis=-1).max() <= radius

print(len(grid_earth))


def test_create_spherical_grid_all_sky():
grid = create_random_spherical_layout(10000)
plt.scatter(grid.ra.rad, grid.dec.rad, marker='o', alpha=0.1)
plt.show()

assert len(grid) == 10000
assert len(grid) == 10000
7 changes: 3 additions & 4 deletions dsa2000_cal/dsa2000_cal/common/tests/test_coord_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from astropy.wcs import WCS
from tomographic_kernel.frames import ENU

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.coord_utils import earth_location_to_uvw_approx, icrs_to_lmn, lmn_to_icrs, earth_location_to_enu, \
icrs_to_enu, enu_to_icrs, lmn_to_icrs_old
from dsa2000_cal.common.quantity_utils import quantity_to_jnp
Expand Down Expand Up @@ -192,7 +191,7 @@ def test_earth_location_to_enu():
array_location = ac.EarthLocation.of_site('vla')
time = at.Time('2000-01-01T00:00:00', format='isot')
enu = earth_location_to_enu(antennas, array_location, time)
assert np.linalg.norm(dsa2000_cal.common.mixed_precision_utils.T) < 6400 * au.km
assert np.linalg.norm(enu.cartesian.xyz.T) < 6400 * au.km

enu_frame = ENU(location=ac.EarthLocation.of_site('vla'), obstime=time)
n = 500
Expand All @@ -202,7 +201,7 @@ def test_earth_location_to_enu():
up=np.random.uniform(size=(n,), low=-5, high=5) * au.km,
frame=enu_frame
).transform_to(ac.ITRS).earth_location
enu = dsa2000_cal.common.mixed_precision_utils.T
enu = earth_location_to_enu(antennas, array_location, time).cartesian.xyz.T
# print(enu)

dist = np.linalg.norm(enu[:, None, :] - enu[None, :, :], axis=-1)
Expand All @@ -215,7 +214,7 @@ def test_icrs_to_enu():
time = at.Time('2000-01-01T00:00:00', format='isot')
enu = icrs_to_enu(sources, array_location, time)
print(enu)
np.testing.assert_allclose(np.linalg.norm(dsa2000_cal.common.mixed_precision_utils.T), 1.)
np.testing.assert_allclose(np.linalg.norm(enu.cartesian.xyz.T), 1.)

reconstruct_sources = enu_to_icrs(enu)
print(reconstruct_sources)
Expand Down
5 changes: 3 additions & 2 deletions dsa2000_cal/dsa2000_cal/common/tests/test_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_calc_noise_full_observation():
def test_calc_noise_8000chan_1hour():
num_antennas = 2048
system_equivalent_flux_density = 5022. # Jy
chan_width_hz = 130000.0 # Hz
chan_width_hz = 162500.0 # Hz
t_int_s = 1.5 # s
num_channels = 8000
num_integrations = 3600. / t_int_s
Expand All @@ -31,6 +31,7 @@ def test_calc_noise_8000chan_1hour():
flag_frac=flag_frac,
num_pol=2
)
print(chan_width_hz * num_channels)
print(f"Image noise (1h 35% flagged): {image_noise} Jy / beam")
assert np.isclose(image_noise, 1e-6, atol=1e-6)

Expand All @@ -40,7 +41,7 @@ def test_calc_noise_8000chan_1hour():
t_int_s=t_int_s
)
print(f"Baseline noise: {baseline_noise} Jy")
assert np.isclose(baseline_noise, 11.37, atol=1e-1)
assert np.isclose(baseline_noise, 10.17, atol=1e-1)


def test_calc_noise_40chan_6s_dsa():
Expand Down
9 changes: 4 additions & 5 deletions dsa2000_cal/dsa2000_cal/common/tests/test_vec_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import numpy as np
from jax import numpy as jnp

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.vec_utils import vec, unvec, kron_product, kron_inv


Expand Down Expand Up @@ -56,17 +55,17 @@ def f(a, b, c):
print()
print("a @ b @ c.T.conj")
print("Naive a.b.c | unvec(kron(c.T, a).vec(b))")
a1 = jax.jit(lambda a, b, c: f(a, b, dsa2000_cal.common.mixed_precision_utils.T)).lower(a, b, c).compile().cost_analysis()[0]
a2 = jax.jit(lambda a, b, c: kron_product(a, b, dsa2000_cal.common.mixed_precision_utils.T)).lower(a, b, c).compile().cost_analysis()[0]
a1 = jax.jit(lambda a, b, c: f(a, b, c.conj().T)).lower(a, b, c).compile().cost_analysis()[0]
a2 = jax.jit(lambda a, b, c: kron_product(a, b, c.conj().T)).lower(a, b, c).compile().cost_analysis()[0]
for key in ['bytes accessed', 'flops', 'utilization operand 0 {}', 'utilization operand 1 {}',
'utilization operand 2 {}', 'bytes accessed output {}']:
print(key, ":", a1.get(key, None), a2.get(key, None))

print()
print("a @ b @ c.conj.T")
print("Naive a.b.c | unvec(kron(c.T, a).vec(b))")
a1 = jax.jit(lambda a, b, c: f(a, b, dsa2000_cal.common.mixed_precision_utils.T)).lower(a, b, c).compile().cost_analysis()[0]
a2 = jax.jit(lambda a, b, c: kron_product(a, b, dsa2000_cal.common.mixed_precision_utils.T)).lower(a, b, c).compile().cost_analysis()[0]
a1 = jax.jit(lambda a, b, c: f(a, b, c.conj().T)).lower(a, b, c).compile().cost_analysis()[0]
a2 = jax.jit(lambda a, b, c: kron_product(a, b, c.conj().T)).lower(a, b, c).compile().cost_analysis()[0]
for key in ['bytes accessed', 'flops', 'utilization operand 0 {}', 'utilization operand 1 {}',
'utilization operand 2 {}', 'bytes accessed output {}']:
print(key, ":", a1.get(key, None), a2.get(key, None))
Expand Down
3 changes: 1 addition & 2 deletions dsa2000_cal/dsa2000_cal/common/tests/test_wgridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
from jax import numpy as jnp

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.jax_utils import multi_vmap, convert_to_ufunc
from dsa2000_cal.common.mixed_precision_utils import mp_policy
from dsa2000_cal.common.wgridder import vis_to_image, image_to_vis
Expand Down Expand Up @@ -118,7 +117,7 @@ def test_spectral_predict(center_offset: float):
)
)(dirty, dl, dm, l0, m0, freqs)
assert np.shape(vis) == (num_freqs, len(uvw), 1)
vis = dsa2000_cal.common.mixed_precision_utils.T # [num_rows, chan]
vis = vis.T # [num_rows, chan]
assert np.all(vis[:, 0] == vis[:, 1])

dirty_rec = vis_to_image(
Expand Down
5 changes: 2 additions & 3 deletions dsa2000_cal/dsa2000_cal/delay_models/far_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from astropy import coordinates as ac, time as at, units as au, constants as const
from jax import config, numpy as jnp, lax

import dsa2000_cal.common.mixed_precision_utils
from dsa2000_cal.common.interp_utils import InterpolatedArray
from dsa2000_cal.common.jax_utils import multi_vmap
from dsa2000_cal.common.quantity_utils import quantity_to_jnp
from dsa2000_cal.common.mixed_precision_utils import mp_policy
from dsa2000_cal.common.quantity_utils import quantity_to_jnp
from dsa2000_cal.delay_models.uvw_utils import perley_icrs_from_lmn, celestial_to_cartesian, norm, norm2


Expand Down Expand Up @@ -72,7 +71,7 @@ def __post_init__(self):
if self.resolution is None:
# compute max baseline
antenna_1, antenna_2 = np.asarray(list(itertools.combinations(range(len(self.antennas)), 2))).T
antennas_itrs = dsa2000_cal.common.mixed_precision_utils.T
antennas_itrs = self.antennas.get_itrs().cartesian.xyz.T
max_baseline = np.max(np.linalg.norm(antennas_itrs[antenna_2] - antennas_itrs[antenna_1], axis=-1))
# Select resolution to keep interpolation error below 1 mm
if max_baseline <= 10 * au.km:
Expand Down
Loading

0 comments on commit ebdb70f

Please sign in to comment.