Skip to content

Commit

Permalink
Incremental nufft (#278)
Browse files Browse the repository at this point in the history
* method in lform and lwcs, since it's occasionally useful when associating m and lx

* More fourier interpolation, including incremental u2nu

* Improved infer_bin_edges

* Added bench.py

* regreplace

* remove commented-out wip code

---------

Co-authored-by: Sigurd K Naess <sigurdkn@login34.chn.perlmutter.nersc.gov>
  • Loading branch information
amaurea and Sigurd K Naess authored Dec 12, 2024
1 parent b5c4c1d commit 3c79a01
Show file tree
Hide file tree
Showing 6 changed files with 360 additions and 101 deletions.
85 changes: 85 additions & 0 deletions pixell/bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import time as _time
from contextlib import contextmanager
from . import bunch

"""bench: Simple timing of python code blocks.
Example usage:
1. Manual printing
from pixell import bench
for i in range(nfile):
with bench.mark("all"):
with bench.mark("read"):
a = np.loadtxt(afiles[i])
b = np.loadtxt(bfiles[i])
with bench.mark("sum"):
a += b
with bench.mark("write"):
np.savetxt(ofiles[i], a)
print("Processed case %d in %7.4f s. read %7.4f sum %7.4f write %7.4f" % (i, bench.t.all, bench.t.read, bench.t.sum, bench.t.write))
print("Total %7.4f s. read %7.4f sum %7.4f write %7.4f" % (i, bench.t_tot.all, bench.t_tot.read, bench.t_tot.sum, bench.t_tot.write))
2. Quick-and-dirty printing
from pixell import bench
for i in range(nfile):
with bench.show("read"):
a = np.loadtxt(afiles[i])
b = np.loadtxt(bfiles[i])
with bench.show("sum"):
a += b
with bench.show("write"):
np.savetxt(ofiles[i], a)
bench.show is equivalent to bench.mark, just with an extra print.
This means that bench.show updates .ttot and .n just like bench.mark
does.
The examples above collect statistics globally. You can create local
benchmark objects with bench.Bench(). Example:
from pixell import bench
mybench = bench.Bench()
with mybench.mark("example"):
do_something()
The overhead of bench.mark is around 3 µs.
"""

# Just wall times for now, but could be extended to measure
# cpu time or leaked memory
class Bench:
def __init__(self):
self.t_tot = bunch.Bunch()
self.t = bunch.Bunch()
self.n = bunch.Bunch()
@contextmanager
def mark(self, name):
if name not in self.n:
self.t_tot[name] = 0
self.n[name] = 0
t1 = _time.time()
try:
yield
finally:
t2 = _time.time()
self.n[name] += 1
self.t [name] = t2-t1
self.t_tot[name] += t2-t1
@contextmanager
def show(self, name):
try:
with self.mark(name):
yield
finally:
print("%7.4f s (last) %7.4f s (mean) %4d (n) %s" % (self.t[name], self.t_tot[name]/self.n[name], self.n[name], name))

# Global interface
_default = Bench()
mark = _default.mark
show = _default.show
t_tot = _default.t_tot
t = _default.t
n = _default.n
4 changes: 4 additions & 0 deletions pixell/curvedsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ def __init__(self, lmax=None, mmax=None, nalm=None, stride=1, layout="triangular
if nalm is not None:
assert self.nelem == nalm, "lmax must be explicitly specified when lmax != mmax"
self.mstart= mstart.astype(np.uint64, copy=False)
@property
def nl(self): return self.lmax+1
@property
def nm(self): return self.mmax+1
def lm2ind(self, l, m):
return (self.mstart[m].astype(int, copy=False)+l*self.stride).astype(int, copy=False)
def get_map(self):
Expand Down
203 changes: 144 additions & 59 deletions pixell/enmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def posaxes(self, safe=True, corner=False, dtype=np.float64): return posaxes(sel
def pixmap(self): return pixmap(self.shape, self.wcs)
def laxes(self, oversample=1, method="auto"): return laxes(self.shape, self.wcs, oversample=oversample, method=method)
def lmap(self, oversample=1): return lmap(self.shape, self.wcs, oversample=oversample)
def lform(self): return lform(self)
def lform(self, method="auto"): return lform(self, method=method)
def modlmap(self, oversample=1, min=0): return modlmap(self.shape, self.wcs, oversample=oversample, min=min)
def modrmap(self, ref="center", safe=True, corner=False): return modrmap(self.shape, self.wcs, ref=ref, safe=safe, corner=corner)
def lbin(self, bsize=None, brel=1.0, return_nhit=False, return_bins=False, lop=None): return lbin(self, bsize=bsize, brel=brel, return_nhit=return_nhit, return_bins=return_bins, lop=lop)
Expand All @@ -99,12 +99,12 @@ def npix(self): return np.prod(self.shape[-2:])
@property
def geometry(self): return self.shape, self.wcs
def resample(self, oshape, off=(0,0), method="fft", border="wrap", corner=False, order=3): return resample(self, oshape, off=off, method=method, border=border, corner=corner, order=order)
def project(self, shape, wcs, order=3, border="constant", cval=0, safe=True): return project(self, shape, wcs, order, border=border, cval=cval, safe=safe)
def project(self, shape, wcs, mode="spline", order=3, border="constant", cval=0, safe=True): return project(self, shape, wcs, order, mode=mode, border=border, cval=cval, safe=safe)
def extract(self, shape, wcs, omap=None, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None, reverse=False): return extract(self, shape, wcs, omap=omap, wrap=wrap, op=op, cval=cval, iwcs=iwcs, reverse=reverse)
def extract_pixbox(self, pixbox, omap=None, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None, reverse=False): return extract_pixbox(self, pixbox, omap=omap, wrap=wrap, op=op, cval=cval, iwcs=iwcs, reverse=reverse)
def insert(self, imap, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None): return insert(self, imap, wrap=wrap, op=op, cval=cval, iwcs=iwcs)
def insert_at(self, pix, imap, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None): return insert_at(self, pix, imap, wrap=wrap, op=op, cval=cval, iwcs=iwcs)
def at(self, pos, order=3, border="constant", cval=0.0, unit="coord", safe=True): return at(self, pos, order, border=border, cval=0, unit=unit, safe=safe)
def at(self, pos, mode="spline", order=3, border="constant", cval=0.0, unit="coord", safe=True, ip=None): return at(self, pos, mode=mode, order=order, border=border, cval=0, unit=unit, safe=safe, ip=ip)
def argmax(self, unit="coord"): return argmax(self, unit=unit)
def autocrop(self, method="plain", value="auto", margin=0, factors=None, return_info=False): return autocrop(self, method, value, margin, factors, return_info)
def apod(self, width, profile="cos", fill="zero"): return apod(self, width, profile=profile, fill=fill)
Expand Down Expand Up @@ -546,11 +546,51 @@ def contains(shape, wcs, pos, unit="coord"):
else: pix = pos
return np.all((pix>=0)&(pix.T<shape[-2:]).T,0)

def project(map, shape, wcs, order=3, border="constant", cval=0.0, force=False, safe=True, bsize=1000, ip=None):
"""Project the map into a new map given by the specified
shape and wcs, interpolating as necessary.
This uses local interpolation, and will lose information
when downgrading compared to averaging down."""
def project(map, shape, wcs, mode="spline", order=3, border="constant",
cval=0.0, force=False, safe=True, bsize=1000, context=50, ip=None):
"""Project map to a new geometry.
This function is not suited for going down in resolution, because
only interpolation is done, not averaging. This means that if
the output geometry has lower resolution than the input, then
information will be lost because noise will not average down the
way it optimally would.
* map: enmap.ndmap of shape [...,ny,nx]
* shape, wcs: The geometry to project to
* mode: The interpolation mode. Same meaning as in utils.interpol.
Valid values are "nearest", "linear", "cubic", "spline" and "fourier".
"nearest" and "linear" are local interpolations, where one does
not need to worry about edge effects and ringing. "cubic" and
especially "fourier" are sensitive to the boundary conditions,
and maps may need to be apodized first. Only "fourier"
interpolation preserves map power on all scales. The other
types lose a bit of power at high multipoles.
fourier > cubic > linear > nearest for high-l fidelity.
"spline" is a generalization of "nearest", "linear" and "cubic",
depending on the "order" argument: 0, 1 and 3.
* order: Controls the "spline" mode. See above.
* border: The boundary condition to assume for spline interpolation.
Ignored for Fourier-interpolation, which always assumes periodic
boundary conditions. Defaults to "constant", where areas outside
the map are assumed to have the constant value "cval".
* cval: See "border".
* force: Force interpolation, even when the input and output pixels
are directly compatible, so no interpolation is necessary. Normally
the faster enmap.extract is used in this case.
* safe: If True (default) make extra effort to resolve 2pi sky
wrapping degeneracies in the coordinate conversion.
* bsize: The interpolation is done in blocks in the y axis to save
memory. This argument controls how many rows are processed at once.
* context: How much to pad each y block by. Used to avoid ringing due
to discontinuities at block boundaries. Defaults to 50.
* ip: An interpolator object as returned by utils.interpolator(). If
provided, this interpolator is used directly, and the interpolation
arguments (mode, order, border, cval) are ignored. If the
interpolator does not count as "prefiltered" (meaning that each use of
the interpolator could incurr a potentially large cost regardless of
how few points are interpolated), then the whole map is processed in
one go, ignoring bsize"""
# Skip expensive operation if map is compatible
if not force:
if wcsutils.equal(map.wcs, wcs) and tuple(shape[-2:]) == tuple(shape[-2:]):
Expand All @@ -560,15 +600,28 @@ def project(map, shape, wcs, order=3, border="constant", cval=0.0, force=False,
omap = zeros(map.shape[:-2]+shape[-2:], wcs, map.dtype)
# Save memory by looping over rows. This won't work for non-"prefiltered" interpolators
if ip and not ip.prefiltered: bsize=100000000
# Avoid unneccessary padding for local cases
if ip or (mode == "spline" and order == 0): context = 0
elif mode == "spline" and order == 1 : context = 1
# It would have been nice to be able to use padtiles here, but
# the input and output tilings are very different
for i1 in range(0, shape[-2], bsize):
i2 = min(i1+bsize, shape[-2])
somap = omap[...,i1:i2,:]
pix = map.sky2pix(somap.posmap(), safe=safe)
y1 = max(np.min(pix[0]).astype(int)-3,0)
y2 = min(np.max(pix[0]).astype(int)+3,map.shape[-2])
if y2-y1 <= 0: continue
pix[0] -= y1
somap[:] = utils.interpol(map[...,y1:y2,:], pix, order=order, border=border, cval=cval, ip=ip)
if ip:
# Can't subdivide interpolator
band = map
else:
y1 = np.min(pix[0]).astype(int)-context
y2 = np.max(pix[0]).astype(int)+context+1
pix[0]-= y1
band = map.extract_pixbox([[y1,0],[y2,map.shape[-1]]])
# Apodize if necessary
if context > 1:
band = apod(band, width=(context,0), fill="crossfade")
# And do the interpolation
somap[:] = utils.interpol(band, pix, mode=mode, order=order, border=border, cval=cval, ip=ip)
return omap

def pixbox_of(iwcs,oshape,owcs):
Expand Down Expand Up @@ -721,9 +774,9 @@ def neighborhood_pixboxes(shape, wcs, poss, r):
res[...,1,:] += 1
return res

def at(map, pos, order=3, border="constant", cval=0.0, unit="coord", safe=True, ip=None):
def at(map, pos, mode="spline", order=3, border="constant", cval=0.0, unit="coord", safe=True, ip=None):
if unit != "pix": pos = sky2pix(map.shape, map.wcs, pos, safe=safe)
return utils.interpol(map, pos, order=order, border=border, cval=cval, ip=ip)
return utils.interpol(map, pos, mode=mode, order=order, border=border, cval=cval, ip=ip)

def argmax(map, unit="coord"):
"""Return the coordinates of the maximum value in the specified map.
Expand Down Expand Up @@ -1962,21 +2015,17 @@ def crop_geometry(shape, wcs, box=None, pixbox=None, oshape=None):
# pixel i would be included for i-0.5, but not for i-0.6. We should
# thefore use rounding boundaries, we just have to make sure it's
# numerically stable


print("box", box/utils.degree)
print("mid", np.mean(box,0)/utils.degree)
print("pixbox", pixbox)
pixbox = utils.nint(pixbox)
print("pixbox2", pixbox)


#print("box", box/utils.degree)
#print("mid", np.mean(box,0)/utils.degree)
#print("pixbox", pixbox)
#pixbox = utils.nint(pixbox)
#print("pixbox2", pixbox)
# Handle 1d case
if pixbox.ndim == 1:
if oshape is None: raise ValueError("crop_geometry needs an explicit output shape when given a 1d box (i.e. a single point instead of a bounding box")
shp = np.array(oshape[-2:])
pixbox = np.array([pixbox-shp//2,pixbox-shp//2+shp])
print("pixbox3", pixbox)
#print("pixbox3", pixbox)
# Can now proceed assuming 2d
oshape = tuple(shape[:-2]) + tuple(np.abs(pixbox[1]-pixbox[0]))
owcs = wcs.deepcopy()
Expand Down Expand Up @@ -2144,7 +2193,7 @@ def shrink_mask(mask, r):
"""Shrink the True part of boolean mask "mask" by a distance of r radians"""
return mask.distance_transform(rmax=r) >= r

def pad(emap, pix, return_slice=False, wrap=False):
def pad(emap, pix, return_slice=False, wrap=False, value=0):
"""Pad enmap "emap", creating a larger map with zeros filled in on the sides.
How much to pad is controlled via pix, which har format [{from,to},{y,x}],
[{y,x}] or just a single number to apply on all sides. E.g. pix=5 would pad
Expand All @@ -2159,7 +2208,7 @@ def pad(emap, pix, return_slice=False, wrap=False):
w = emap.wcs.deepcopy()
w.wcs.crpix += pix[0,::-1]
# Construct a slice between the new and old map
res = zeros(emap.shape[:-2]+tuple([s+sum(p) for s,p in zip(emap.shape[-2:],pix.T)]),wcs=w, dtype=emap.dtype)
res = full(emap.shape[:-2]+tuple([s+sum(p) for s,p in zip(emap.shape[-2:],pix.T)]),wcs=w, val=value, dtype=emap.dtype)
mslice = (Ellipsis,slice(pix[0,0],res.shape[-2]-pix[1,0]),slice(pix[0,1],res.shape[-1]-pix[1,1]))
res[mslice] = emap
if wrap:
Expand Down Expand Up @@ -2263,36 +2312,72 @@ def _widen(map,n):
def laplace(m):
return -ifft(fft(m)*np.sum(m.lmap()**2,0)).real

def apod(m, width, profile="cos", fill="zero"):
"""Apodize the provided map. Currently only cosine apodization is
implemented.
Args:
imap: (...,Ny,Nx) or (Ny,Nx) ndarray to be apodized
width: The width in pixels of the apodization on each edge.
profile: The shape of the apodization. Only "cos" is supported.
"""
width = np.minimum(np.zeros(2)+width,m.shape[-2:]).astype(np.int32)
if profile == "cos":
a = [0.5*(1-np.cos(np.linspace(0,np.pi,w))) for w in width]
else:
raise ValueError("Unknown apodization profile %s" % profile)
res = m.copy()
#def apod(m, width, profile="cos", fill="zero"):
# """Apodize the provided map. Currently only cosine apodization is
# implemented.
#
# Args:
# imap: (...,Ny,Nx) or (Ny,Nx) ndarray to be apodized
# width: The width in pixels of the apodization on each edge.
# profile: The shape of the apodization. Only "cos" is supported.
# """
# width = np.minimum(np.zeros(2)+width,m.shape[-2:]).astype(np.int32)
# if profile == "cos":
# a = [0.5*(1-np.cos(np.linspace(0,np.pi,w))) for w in width]
# else:
# raise ValueError("Unknown apodization profile %s" % profile)
# res = m.copy()
# if fill == "mean":
# offset = np.asarray(np.mean(res,(-2,-1)))[...,None,None]
# res -= offset
# elif fill == "median":
# offset = np.asarray(np.median(res,(-2,-1)))[...,None,None]
# res -= offset
# if width[0] > 0:
# res[...,:width[0],:] *= a[0][:,None]
# res[...,-width[0]:,:] *= a[0][::-1,None]
# if width[1] > 0:
# res[...,:,:width[1]] *= a[1][None,:]
# res[...,:,-width[1]:] *= a[1][None,::-1]
# if fill == "mean" or fill == "median":
# res += offset
# return res

def apod(map, width, profile="cos", fill="zero", inplace=False):
width = (np.zeros(2,int)+width).astype(int)
if not inplace: map = map.copy()
if fill == "mean":
offset = np.asarray(np.mean(res,(-2,-1)))[...,None,None]
res -= offset
offset = np.mean(map,(-2,-1))[...,None,None]
map -= offset
elif fill == "median":
offset = np.asarray(np.median(res,(-2,-1)))[...,None,None]
res -= offset
if width[0] > 0:
res[...,:width[0],:] *= a[0][:,None]
res[...,-width[0]:,:] *= a[0][::-1,None]
if width[1] > 0:
res[...,:,:width[1]] *= a[1][None,:]
res[...,:,-width[1]:] *= a[1][None,::-1]
if fill == "mean" or fill == "median":
res += offset
return res
offset = np.median(map,(-2,-1))[...,None,None]
map -= offset
# Process the axes one by one
for i, w in enumerate(width):
if w <= 0: continue
if fill == "crossfade":
x = np.arange(1,w+1,dtype=map.dtype)/(2*w+1)
else:
x = np.arange(1,w+1,dtype=map.dtype)/(w+1)
if profile == "lin": prof = apod_profile_lin(x)
elif profile == "cos": prof = apod_profile_cos(x)
else: raise ValueError("Unknown apodization profile '%s'" % str(profile))
# Apply the apodization
slice1 = (Ellipsis,)+(slice(None),)*i +(slice(0,w),)+(slice(None),)*(1-i)
slice2 = (Ellipsis,)+(slice(None),)*i +(slice(-w,None),)+(slice(None),)*(1-i)
broad = (None,)*i+(slice(None),)+(None,)*(1-i)
m1 = map[slice1].copy()
m2 = map[slice2].copy()
if fill == "crossfade":
map[slice1] = m1*(1-prof)[::-1][broad]+m2*prof[::-1][broad]
map[slice2] = m2*(1-prof)[broad]+m1*prof[broad]
elif fill in ["mean", "median", "zero"]:
map[slice1] *= prof[broad]
map[slice2] *= prof[::-1][broad]
# Add in offsets if necessary
if fill in ["mean", "median"]:
map += offset
return map

def apod_profile_lin(x): return x
def apod_profile_cos(x): return 0.5*(1-np.cos(np.pi*x))
Expand All @@ -2310,7 +2395,7 @@ def apod_mask(mask, width=1*utils.degree, edge=True, profile=apod_profile_cos):
r = mask.distance_transform(rmax=width)
return profile(r/width)

def lform(map):
def lform(map, method="auto"):
"""Given an enmap, return a new enmap that has been fftshifted (unless shift=False),
and which has had the wcs replaced by one describing fourier space. This is mostly
useful for plotting or writing 2d power spectra.
Expand All @@ -2319,12 +2404,12 @@ def lform(map):
are assumed to need conversion between degrees and radians, sky2pix etc. get confused
when applied to lform-maps."""
omap = fftshift(map)
omap.wcs = lwcs(map.shape, map.wcs)
omap.wcs = lwcs(map.shape, map.wcs, method=method)
return omap

def lwcs(shape, wcs):
def lwcs(shape, wcs, method="auto"):
"""Build world coordinate system for l-space"""
lres = 2*np.pi/extent(shape, wcs, signed=True)
lres = 2*np.pi/extent(shape, wcs, signed=True, method=method)
ny, nx = shape[-2:]
owcs = wcsutils.explicit(crpix=[nx//2+1,ny//2+1], crval=[0,0], cdelt=lres[::-1])
return owcs
Expand Down
Loading

0 comments on commit 3c79a01

Please sign in to comment.