Skip to content

Commit

Permalink
restructure simulation function for GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox committed Aug 11, 2024
1 parent 0827abd commit 3c6a4ab
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 68 deletions.
3 changes: 2 additions & 1 deletion fftvis/beams.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
from pyuvsim import AnalyticBeam
from pyuvdata import UVBeam


def _evaluate_beam(
A_s: np.ndarray,
beam: UVBeam,
beam: UVBeam | AnalyticBeam,
az: np.ndarray,
za: np.ndarray,
polarized: bool,
Expand Down
179 changes: 112 additions & 67 deletions fftvis/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,20 @@
import logging
import tracemalloc as tm
from pyuvdata import UVBeam
from pyuvsim import AnalyticBeam

from . import utils, beams, logutils

try:
import cufinufft
import cupy as cp
from cupyx.scipy import interpolate, ndimage

HAVE_CUDA = True
except ImportError:
# if not installed, don't warn
HAVE_CUDA = False

# Default accuracy for the non-uniform fast fourier transform based on precision
default_accuracy_dict = {
1: 6e-8,
Expand Down Expand Up @@ -177,6 +188,8 @@ def simulate(
Tolerance for checking if the array is flat in units of meters. If the
z-coordinate of all baseline vectors is within this tolerance, the array
is considered flat and the z-coordinate is set to zero. Default is 0.0.
use_gpu : bool, default = False
Whether to use the GPU for the simulation. Default is False.
Returns:
-------
Expand Down Expand Up @@ -312,79 +325,18 @@ def simulate(
# Compute uv coordinates
u[:], v[:], w[:] = blx * freqs[fi], bly * freqs[fi], blz * freqs[fi]

# Compute beams - only single beam is supported
A_s = np.zeros((nax, nfeeds, nsim_sources), dtype=complex_dtype)
A_s = beams._evaluate_beam(
A_s,
beam_here,
# Evaluate the RIME
_evaluate_rime(
beam,
az,
za,
polarized,
freqs[fi],
is_beam_complex,
is_coplanar,
spline_opts=beam_spline_opts,
use_gpu=use_gpu
)
A_s = A_s.transpose((1, 0, 2))
beam_product = np.einsum("abs,cbs->acs", A_s.conj(), A_s)
beam_product = beam_product.reshape(nax * nfeeds, nsim_sources)

# Compute sky beam product
i_sky = beam_product * Isky[above_horizon, fi]

# Compute visibilities w/ non-uniform FFT
if is_coplanar:
_vis_here = finufft.nufft2d3(
2 * np.pi * tx,
2 * np.pi * ty,
i_sky,
u,
v,
modeord=0,
eps=eps,
)
else:
_vis_here = finufft.nufft3d3(
2 * np.pi * tx,
2 * np.pi * ty,
2 * np.pi * tz,
i_sky,
u,
v,
w,
modeord=0,
eps=eps,
)

# Expand out the visibility array
_vis[..., fi] = _vis_here.reshape(nfeeds, nfeeds, nbls)

# If beam is complex, we need to compute the reverse negative frequencies
if is_beam_complex and expand_vis:
# Compute
if is_coplanar:
_vis_here_neg = finufft.nufft2d3(
2 * np.pi * tx,
2 * np.pi * ty,
i_sky,
-u,
-v,
modeord=0,
eps=eps,
)
else:
_vis_here_neg = finufft.nufft3d3(
2 * np.pi * tx,
2 * np.pi * ty,
2 * np.pi * tz,
i_sky,
-u,
-v,
-w,
modeord=0,
eps=eps,
)
_vis_negatives[..., fi] = _vis_here_neg.reshape(
nfeeds, nfeeds, nbls
)

# Expand out the visibility array in antenna by antenna matrix
if expand_vis:
Expand Down Expand Up @@ -444,3 +396,96 @@ def simulate(
if polarized
else np.moveaxis(vis[..., 0, 0, :], 2, 0)
)

def _evaluate_rime(
beam: UVBeam | AnalyticBeam,
az: np.ndarray,
za: np.ndarray,
polarized: bool,
freq: float,
is_beam_complex: bool,
is_coplanar: bool,
check: bool = False,
spline_opts: dict = None,
use_gpu: bool = False,
):
"""
"""
if use_gpu:
finufft_2D, finufft_3D = cufinufft.nufft2d3, cufinufft.nufft3d3
else:
finufft_2D, finufft_3D = finufft.nufft2d3, finufft.nufft3d3

# Compute beams - only single beam is supported
A_s = np.zeros((nax, nfeeds, nsim_sources), dtype=complex_dtype)
A_s = beams._evaluate_beam(
A_s,
beam_here,
az,
za,
polarized,
freqs[fi],
spline_opts=spline_opts,
)
A_s = A_s.transpose((1, 0, 2))
beam_product = np.einsum("abs,cbs->acs", A_s.conj(), A_s)
beam_product = beam_product.reshape(nax * nfeeds, nsim_sources)

# Compute sky beam product
i_sky = beam_product * Isky[above_horizon, fi]

# Compute visibilities w/ non-uniform FFT
if is_coplanar:
_vis_here = finufft.nufft2d3(
2 * np.pi * tx,
2 * np.pi * ty,
i_sky,
u,
v,
modeord=0,
eps=eps,
)
else:
_vis_here = finufft.nufft3d3(
2 * np.pi * tx,
2 * np.pi * ty,
2 * np.pi * tz,
i_sky,
u,
v,
w,
modeord=0,
eps=eps,
)

# Expand out the visibility array
_vis[..., fi] = _vis_here.reshape(nfeeds, nfeeds, nbls)

# If beam is complex, we need to compute the reverse negative frequencies
if is_beam_complex and expand_vis:
# Compute
if is_coplanar:
_vis_here_neg = finufft.nufft2d3(
2 * np.pi * tx,
2 * np.pi * ty,
i_sky,
-u,
-v,
modeord=0,
eps=eps,
)
else:
_vis_here_neg = finufft.nufft3d3(
2 * np.pi * tx,
2 * np.pi * ty,
2 * np.pi * tz,
i_sky,
-u,
-v,
-w,
modeord=0,
eps=eps,
)
_vis_negatives[..., fi] = _vis_here_neg.reshape(
nfeeds, nfeeds, nbls
)

0 comments on commit 3c6a4ab

Please sign in to comment.