From a40db5bc060f53a95e4e187d061bc8c21c9f756a Mon Sep 17 00:00:00 2001 From: amaurea Date: Fri, 8 Nov 2024 17:23:52 +0100 Subject: [PATCH] Update from development repo (#276) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fft.resample, np.product→np.prod and some build fixes * Added a small part to the makefile to make my own type of editable install more convenient. It's optional, and people who don't use it shouldn't be impacted. Also made the temporary linkables used as a dependency for the cython compile static instead of shared, since they are just an intermediate step towards the final shared library. That way we don't end up with too many similar but confusingly named shared libraries that depend on each other, e.g. lib_cmisc_shared.so being a dependency of the cmisc.so python extension * Faster wheels build * Oops, forgot to update wheel build directory * Back to original wheel builds * Started to port over work on wcs acceleration from failed branch * Make fejer the default geometry in fullsky_geometry and band_geometry. More general nditer. Disk overlap stuff. Nufft interface. * More general bunch write * Fixed missing nthread in fft.nufft * More sensible dtypes in ubash * Updated project and resample supporting non-equispaced fft interpolation * Fix broken utils.allgatherv * Work on new geometry stuff. Hasn't replaced the old interface. * New version * testing if things will work without the oldest versions * testing newer gcc * Trying -ld64 * back * Try to run on newer version of MacOS for Healpy compatibility * Run tests on arm instead of x64 * Just use macos-latest for all but pinned versions * Update checkout and setup-python? 
* bump version again, to be safe --------- Co-authored-by: Josh Borrow --- .github/workflows/build.yml | 27 ++- Makefile | 14 ++ cython/cmisc.pyx | 2 +- cython/cmisc_core.c | 52 +++++ cython/distances_core.c | 2 +- fortran/interpol.F90 | 2 +- meson.build | 6 +- pixell/aberration.py | 2 +- pixell/bunch.py | 118 ++++++---- pixell/colorize.py | 2 +- pixell/curvedsky.py | 2 +- pixell/enmap.py | 274 ++++++++++++++++++----- pixell/fft.py | 292 ++++++++++++++++++++++++- pixell/pointsrcs.py | 6 +- pixell/uharm.py | 2 +- pixell/utils.py | 420 +++++++++++++++++++++++++++++++----- pixell/wcsutils.py | 348 ++++++++++++++++++++++++------ pyproject.toml | 2 +- tests/test_pixell.py | 14 +- 19 files changed, 1333 insertions(+), 254 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0a787e60..d193b83d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,11 +9,11 @@ jobs: strategy: matrix: - python: ["3.12", "3.11", "3.10", "3.9"] + python: ["3.12", "3.11", "3.10"] steps: - - uses: actions/checkout@v1 - - uses: actions/setup-python@v1 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python }} @@ -30,19 +30,18 @@ jobs: test-macos: name: "Run tests on MacOS" - runs-on: macos-12 + runs-on: macos-latest env: - # LDFLAGS: "-ld64" # For MacOS 13 and above (XCode CLT 15 and above.) 
- CC: gcc-12 - CXX: gcc-12 - FC: gfortran-12 + CC: gcc-14 + CXX: gcc-14 + FC: gfortran-14 DUCC0_NUM_THREADS: 2 steps: - - uses: actions/checkout@v1 - - uses: actions/setup-python@v1 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.11" - name: Install Dependencies (MacOS) run: | @@ -91,7 +90,7 @@ jobs: strategy: matrix: # macos-13 is an intel runner, macos-14 is apple silicon - os: [macos-14] + os: [macos-latest] steps: - uses: actions/checkout@v4 @@ -113,9 +112,9 @@ jobs: name: Build source distribution runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - uses: actions/setup-python@v2 + - uses: actions/setup-python@v5 name: Install Python with: python-version: '3.10' diff --git a/Makefile b/Makefile index ed243fb7..625785a5 100644 --- a/Makefile +++ b/Makefile @@ -76,3 +76,17 @@ dist: clean ## builds source and wheel package install: clean ## install the package to the active Python's site-packages pip install . + + +# Symlink-based alternative setup below. Not intended for general use +# Standard meson build +SHELL=/bin/bash +.PHONY: build inline +inline: build + (shopt -s nullglob; cd pixell; rm -f *.so; ln -s ../build/*.so ../build/*.dylib .) +build: build/build.ninja + (cd build; meson compile) +build/build.ninja: + rm -rf build + mkdir build + meson setup build diff --git a/cython/cmisc.pyx b/cython/cmisc.pyx index 0e7b3141..6b5ea81a 100644 --- a/cython/cmisc.pyx +++ b/cython/cmisc.pyx @@ -31,7 +31,7 @@ def alm2cl(ainfo, alm, alm2=None): # I used to flatten here to make looping simple, but that caused a copy to be made # when combined with np.broadcast. 
So instead I will use manual raveling pshape = alm.shape[:-1] - npre = int(np.product(pshape)) + npre = int(np.prod(pshape)) cdef float[::1] cl_single_sp, alm_single_sp1, alm_single_sp2 cdef double[::1] cl_single_dp, alm_single_dp1, alm_single_dp2 cdef int64_t[::1] mstart = np.ascontiguousarray(ainfo.mstart).view(np.int64) diff --git a/cython/cmisc_core.c b/cython/cmisc_core.c index c73a0a1f..ded2d9ae 100644 --- a/cython/cmisc_core.c +++ b/cython/cmisc_core.c @@ -3,6 +3,7 @@ #include #include #include +#include #include int min(int a, int b) { return a < b ? a : b; } @@ -269,3 +270,54 @@ void transfer_alm_sp(int lmax1, int mmax1, int64_t * mstart1, float * alm1, int } } } + +// wcs acceleration +#define DEG (M_PI/180) + +// I pass the wcs information as individual doubles to avoid having to construct +// numpy arrays on the python side. All of these have the arguments in the same +// order, regardless of which way they go, so they can be defined in a macro +// +// We only implement plain spherical coordinates here - the final coordinate +// rotation is missing. 
For cylindrical coordinates this means that we only +// support dec0 = 0 - the result will be wrong for other values +#define wcsdef(name) \ +void name(int64_t n, double * restrict dec, double * restrict ra, \ + double * restrict y, double * restrict x, \ + double crval0, double crval1, double cdelt0, double cdelt1, \ + double crpix0, double crpix1) { \ + double ra0 = crval0*DEG, dec0 = crval1*DEG; \ + double dra = cdelt0*DEG, ddec = cdelt1*DEG; \ + double x0 = crpix0-1, y0 = crpix1-1; \ + _Pragma("omp parallel for") \ + for(int64_t i = 0; i < n; i++) { +#define wcsend } \ +} + +wcsdef(wcs_car_sky2pix) +x[i] = (ra [i]-ra0 )/dra +x0; +y[i] = (dec[i]-dec0)/ddec+y0; // dec0 should be zero +wcsend + +wcsdef(wcs_car_pix2sky) +ra [i] = (x[i]-x0)*dra +ra0; +dec[i] = (y[i]-y0)*ddec+dec0; // dec0 should be zero +wcsend + +wcsdef(wcs_cea_sky2pix) +x[i] = (ra [i]-ra0 )/dra +x0; +y[i] = sin(dec[i])/ddec +y0; +(void)dec0; // mark dec0 as explicitly unused +wcsend + +wcsdef(wcs_cea_pix2sky) +ra [i] = (x[i]-x0)*dra +ra0; +dec[i] = asin((y[i]-y0)*ddec); +(void)dec0; // mark dec0 as explicitly unused +wcsend + +void rewind_inplace(int64_t n, double * vals, double period, double ref) { + _Pragma("omp parallel for") + for(int64_t i = 0; i < n; i++) + vals[i] = fmod((vals[i]+ref),period)-ref; +} diff --git a/cython/distances_core.c b/cython/distances_core.c index 9fc14cea..276a5fca 100644 --- a/cython/distances_core.c +++ b/cython/distances_core.c @@ -15,7 +15,7 @@ int xoffs[8] = {-1, +1, 0, 0, +1, -1, +1, -1 }; double wall_time() { struct timeval tv; gettimeofday(&tv,0); return tv.tv_sec + 1e-6*tv.tv_usec; } int max(int a, int b) { return a > b ? a : b; } int min(int a, int b) { return a < b ? a : b; } -int compar_int(int * a, int * b) { return *a-*b; } +int compar_int(const void * a, const void * b) { return *(int*)a-*(int*)b; } int wrap1(int a, int n) { return a < 0 ? a+n : a >= n ? a-n : a; } // The simple functions are too slow to serve as the basis for a distance transform. 
diff --git a/fortran/interpol.F90 b/fortran/interpol.F90 index a39ae0fc..92c2da88 100644 --- a/fortran/interpol.F90 +++ b/fortran/interpol.F90 @@ -1,6 +1,6 @@ module fortran - private :: map_border, calc_weights + !private :: map_border, calc_weights contains diff --git a/meson.build b/meson.build index faefcdcb..8ff773bf 100644 --- a/meson.build +++ b/meson.build @@ -45,6 +45,8 @@ incdir_f2py = run_command( run_command('make', '-C', 'fortran', check: true) fortran_include = include_directories(incdir_numpy, incdir_f2py) +add_project_arguments('-Wno-tabs', language : 'fortran') +add_project_arguments('-Wno-conversion', language : 'fortran') fortran_sources = { 'fortran/interpol_32.f90': '_interpol_32', @@ -87,7 +89,7 @@ helper_sources = { linkables = [] foreach source_name, module_name : helper_sources - linkables += shared_library( + linkables += static_library( module_name, source_name, install: true, @@ -120,4 +122,4 @@ endforeach # The actual python install itself is left up to a helper build # script deifned in pixell/ subdir('pixell') -subdir('scripts') \ No newline at end of file +subdir('scripts') diff --git a/pixell/aberration.py b/pixell/aberration.py index 210e64e4..771aca90 100644 --- a/pixell/aberration.py +++ b/pixell/aberration.py @@ -116,7 +116,7 @@ def __init__(self, shape, wcs, dir=dir_equ, beta=beta, spin=[0,2], nthread = int(utils.fallback(utils.getenv("OMP_NUM_THREADS",nthread),0)) # 1. Calculate the aberration field. These are tiny alm_dpos = calc_boost_field(-beta, dir, nthread=nthread) - # 2. Evaluate these on our target geometry. Hardcoded float64 because of get_deflected_angles + # 2. Evaluate these on our target geometry. deflect = enmap.zeros(alm_dpos.shape[:-1]+shape[-2:], wcs, coord_dtype) curvedsky.alm2map(alm_dpos.astype(coord_ctype, copy=False), deflect, spin=1, nthread=nthread) # 3. Calculate the offset angles. 
diff --git a/pixell/bunch.py b/pixell/bunch.py index ae2c0e6f..ef490d83 100644 --- a/pixell/bunch.py +++ b/pixell/bunch.py @@ -1,5 +1,6 @@ """My own version of bunch, since the standard one lacks tab completion and has trouble printing sometimes.""" +import os class Bunch: def __init__(self, *args, **kwargs): self._dict = {} @@ -53,61 +54,102 @@ def __repr__(self): # Some simple I/O routines. These can't handle everything that could # be in a bunch, but they cover all my most common use cases. -def read(fname, fmt="auto", group=None): - if fmt == "auto": - if is_hdf_path(fname): fmt = "hdf" - else: raise ValueError("Could not infer format for '%s'" % fname) - if fmt == "hdf": return read_hdf(fname, group=group) +def read(fname, fmt="auto", group=None, gmode="dot"): + if fmt == "auto": fmt="hdf" + if fmt == "hdf": return read_hdf(fname, group=group, gmode=gmode) else: raise ValueError("Unrecognized format '%s'" % fmt) -def write(fname, bunch, fmt="auto", group=None): - if fmt == "auto": - if is_hdf_path(fname): fmt = "hdf" - else: raise ValueError("Could not infer format for '%s'" % fname) - if fmt == "hdf": write_hdf(fname, bunch, group=group) +def write(fname, bunch, fmt="auto", group=None, gmode="dot"): + if fmt == "auto": fmt = "hdf" + if fmt == "hdf": write_hdf(fname, bunch, group=group, gmode=gmode) else: raise ValueError("Unrecognized format '%s'" % fmt) -def write_hdf(fname, bunch, group=None): +def read_hdf(fname, group=None, gmode="dot"): import h5py - fname, group = split_hdf_path(fname, group) + if group is None: + fname, group = split_hdf_path(fname, group, mode=gmode) + with h5py.File(fname, "r") as hfile: + if group: hfile = hfile[group] + return read_hdf_recursive(hfile) + +def read_hdf_recursive(hfile): + import h5py + if isinstance(hfile, h5py.Dataset): + return hfile[()] + else: + bunch = Bunch() + for key in hfile: + bunch[key] = read_hdf_recursive(hfile[key]) + return bunch + +def write_hdf(fname, bunch, group=None, gmode="dot"): + import h5py + 
if group is None: + fname, group = split_hdf_path(fname, group, mode=gmode) with h5py.File(fname, "w") as hfile: if group: hfile = hfile.create_group(group) - for key in bunch: + write_hdf_recursive(hfile, bunch) + +def write_hdf_recursive(hfile, bunch): + for key in bunch: + if isinstance(bunch[key],Bunch): + hfile.create_group(key) + write_hdf_recursive(hfile[key], bunch[key]) + else: hfile[key] = bunch[key] -def read_hdf(fname, group=None): - import h5py - bunch = Bunch() - fname, group = split_hdf_path(fname, group) - with h5py.File(fname, "r") as hfile: - if group: hfile = hfile[group] - for key in hfile: - bunch[key] = hfile[key][()] - return bunch +#def make_safe(val): +# import numpy as np +# if isinstance(val, np.ndarray): +# try: return np.char.encode(val) +# except TypeError: pass +# elif isinstance(val, str): +# return val.encode() +# else: +# return val def is_hdf_path(fname): """Returns true if the fname would be recognized by split_hdf_path""" - for suf in [".hdf", ".h5"]: - name, _, group = fname.rpartition(suf) - if name and (not group or group[0] == "/"): return True - return False + return True -def split_hdf_path(fname, subgroup=None): +def split_hdf_path(fname, subgroup=None, mode="dot"): """Split an hdf path of the form path.hdf/group, where the group part is optional, into the path and the group parts. If subgroup is specified, then it will be appended to the group informaiton. returns fname, group. The fname will be a string, and the group will be a string or None. Raises - a ValueError if the fname is not recognized as a hdf file.""" - for suf in [".hdf", ".h5"]: - name, _, group = fname.rpartition(suf) - if not name: continue - name += suf - if not group: return name, subgroup - elif group[0] == "/": - group = group[1:] - if subgroup: group += "/" + subgroup - return name, group - raise ValueError("Not an hdf path") + a ValueError if unsuccessful. + + mode controles how the split is done: + * "none": Don't split. 
fname is returned unmodified + * "dot": The last entry in the path given by filename + containing a "." will be taken to be the real + file name, the rest till be the hdf group path. + For example, with a/b/c.d/e/f, a/b/c.d would be returned + as the file name and e/f as the hdf group + * "exists": As dot, but based on whether a file with that + name can be found on disk. Seemed like a good idea, + except it doesn't work when writing a new file. + """ + toks = fname.split("/") + if mode == "dot": + # Find last entry with a dot i in it + for i, tok in reversed(list(enumerate(toks))): + if "." in tok: break + else: raise ValueError("Could not split hdf path using 'dot' method: no . found") + elif mode == "exists": + for i in reversed(list(range(len(toks)))): + cand = "/".join(toks[:i+1]) + if os.path.isfile(cand): break + else: raise ValueError("Could not split hdf path using 'exists' method: no file found") + elif mode == "none": + i = len(toks) + else: raise ValueError("Unrecognized split mode '%s'" % (str(mode))) + # Return the result + fname = "/".join(toks[:i+1]) + gtoks = toks[i+1:] + if subgroup: gtoks.append(subgroup) + group = "/".join(gtoks) if len(gtoks)>0 else None + return fname, group def concatenate(bunches): """Go from a list of bunches to a bunch of lists.""" diff --git a/pixell/colorize.py b/pixell/colorize.py index 8f56c57a..e987e599 100644 --- a/pixell/colorize.py +++ b/pixell/colorize.py @@ -169,7 +169,7 @@ def mpl_register(names=None): if isinstance(names, basestring): names = [names] for name in names: cmap = to_mpl_colormap(name, schemes[name]) - matplotlib.cm.register_cmap(name, cmap) + matplotlib.colormaps.register(cmap) def mpl_setdefault(name): import matplotlib.pyplot diff --git a/pixell/curvedsky.py b/pixell/curvedsky.py index 4d3d074b..b6d886db 100644 --- a/pixell/curvedsky.py +++ b/pixell/curvedsky.py @@ -1238,7 +1238,7 @@ def analyse_geometry(shape, wcs, tol=1e-6): # TODO: Pseudo-cylindrical projections can be handled with standard 
ducc synthesis, # so ideally our check would be less stringent than this. Supporinting them requires # more work, so will just do it with the general interface for now. - separable = wcsutils.is_cyl(wcs) + separable = wcsutils.is_separable(wcs) divides = utils.hasoff(360/np.abs(wcs.wcs.cdelt[0]), 0, tol=tol) if not separable or not divides: # Not cylindrical or ra does not evenly divide the sky diff --git a/pixell/enmap.py b/pixell/enmap.py index 22fe0b54..49d0513e 100644 --- a/pixell/enmap.py +++ b/pixell/enmap.py @@ -80,8 +80,8 @@ def lmap(self, oversample=1): return lmap(self.shape, self.wcs, oversample=overs def lform(self): return lform(self) def modlmap(self, oversample=1, min=0): return modlmap(self.shape, self.wcs, oversample=oversample, min=min) def modrmap(self, ref="center", safe=True, corner=False): return modrmap(self.shape, self.wcs, ref=ref, safe=safe, corner=corner) - def lbin(self, bsize=None, brel=1.0, return_nhit=False, return_bins=False): return lbin(self, bsize=bsize, brel=brel, return_nhit=return_nhit, return_bins=return_bins) - def rbin(self, center=[0,0], bsize=None, brel=1.0, return_nhit=False, return_bins=False): return rbin(self, center=center, bsize=bsize, brel=brel, return_nhit=return_nhit, return_bins=return_bins) + def lbin(self, bsize=None, brel=1.0, return_nhit=False, return_bins=False, lop=None): return lbin(self, bsize=bsize, brel=brel, return_nhit=return_nhit, return_bins=return_bins, lop=lop) + def rbin(self, center=[0,0], bsize=None, brel=1.0, return_nhit=False, return_bins=False, rop=None): return rbin(self, center=center, bsize=bsize, brel=brel, return_nhit=return_nhit, return_bins=return_bins, rop=rop) def area(self): return area(self.shape, self.wcs) def pixsize(self): return pixsize(self.shape, self.wcs) def pixshape(self, signed=False): return pixshape(self.shape, self.wcs, signed=signed) @@ -98,13 +98,13 @@ def preflat(self): def npix(self): return np.prod(self.shape[-2:]) @property def geometry(self): return self.shape, 
self.wcs - def resample(self, oshape, off=(0,0), method="fft", mode="wrap", corner=False, order=3): return resample(self, oshape, off=off, method=method, mode=mode, corner=corner, order=order) - def project(self, shape, wcs, order=3, mode="constant", cval=0, prefilter=True, mask_nan=False, safe=True): return project(self, shape, wcs, order, mode=mode, cval=cval, prefilter=prefilter, mask_nan=mask_nan, safe=safe) + def resample(self, oshape, off=(0,0), method="fft", border="wrap", corner=False, order=3): return resample(self, oshape, off=off, method=method, border=border, corner=corner, order=order) + def project(self, shape, wcs, order=3, border="constant", cval=0, safe=True): return project(self, shape, wcs, order, border=border, cval=cval, safe=safe) def extract(self, shape, wcs, omap=None, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None, reverse=False): return extract(self, shape, wcs, omap=omap, wrap=wrap, op=op, cval=cval, iwcs=iwcs, reverse=reverse) def extract_pixbox(self, pixbox, omap=None, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None, reverse=False): return extract_pixbox(self, pixbox, omap=omap, wrap=wrap, op=op, cval=cval, iwcs=iwcs, reverse=reverse) def insert(self, imap, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None): return insert(self, imap, wrap=wrap, op=op, cval=cval, iwcs=iwcs) def insert_at(self, pix, imap, wrap="auto", op=lambda a,b:b, cval=0, iwcs=None): return insert_at(self, pix, imap, wrap=wrap, op=op, cval=cval, iwcs=iwcs) - def at(self, pos, order=3, mode="constant", cval=0.0, unit="coord", prefilter=True, mask_nan=False, safe=True): return at(self, pos, order, mode=mode, cval=0, unit=unit, prefilter=prefilter, mask_nan=mask_nan, safe=safe) + def at(self, pos, order=3, border="constant", cval=0.0, unit="coord", safe=True): return at(self, pos, order, border=border, cval=0, unit=unit, safe=safe) def argmax(self, unit="coord"): return argmax(self, unit=unit) def autocrop(self, method="plain", value="auto", margin=0, factors=None, 
return_info=False): return autocrop(self, method, value, margin, factors, return_info) def apod(self, width, profile="cos", fill="zero"): return apod(self, width, profile=profile, fill=fill) @@ -192,7 +192,7 @@ def helper(b): if xflip: omap = omap[...,:,::-1] return omap -def subinds(shape, wcs, box, mode=None, cap=True, noflip=False): +def subinds(shape, wcs, box, mode=None, cap=True, noflip=False, epsilon=1e-4): """Helper function for submap. Translates the coordinate box provided into a pixel units. @@ -204,18 +204,22 @@ def subinds(shape, wcs, box, mode=None, cap=True, noflip=False): inclusive and exclusive modes break this, and should be used with caution. 2. tiny floating point errors should not usually be able to cause the ibox to change. Most boxes will have some simple fraction of - a whole degree, and most have pixels with centers at a simple fraction - of a whole degree. Hence, it is likely that box edges will fall - almost exactly on an integer pixel value. floor and ceil will - then move us around by a whole pixel based on tiny numerical - jitter around this value. Hence these should be used with caution. - These concerns leave us with mode = "round" as the only generally - safe alternative, which is why it's default. + a whole degree, and most have pixels with centers or pixel edges + at a simple fraction of a whole degree. mode="floor" or "ceil" + break when pixel centers are at whole values. mode="round" + breaks when pixel edges are at whole values. But since small + (but not float-precision-size) offsets from these cases are unlikely, + we can define safe rounding by adding an epsilon to the values + before rounding. As long as this epsilon is use consistently, + box overlap still works. + With epsilon in place, modes "round", "floor" and "ceil" are all safe. + We make "round" the default. 
""" if mode is None: mode = "round" box = np.asarray(box) # Translate the box to pixels bpix = skybox2pixbox(shape, wcs, box, include_direction=True) + bpix[:2] += epsilon if noflip: for b in bpix.T: if b[2] < 0: b[:] = [b[1],b[0],-b[2]] @@ -226,6 +230,8 @@ def subinds(shape, wcs, box, mode=None, cap=True, noflip=False): elif mode == "exclusive": bpix = [np.ceil (bpix[0]),np.floor(bpix[1]), bpix[2]] else: raise ValueError("Unrecognized mode '%s' in subinds" % str(mode)) bpix = np.array(bpix, int) + # A pixel goes from [i1-0.5:i2+0.5] with round(+eps) this becomes [i1:i2+1] + # We therefore don't need to add 1 to get a proper slice if cap: # Make sure we stay inside our map bounds for b, n in zip(bpix.T,shape[-2:]): @@ -430,7 +436,7 @@ def posmap(shape, wcs, safe=True, corner=False, separable="auto", dtype=np.float is 1000x faster. """ res = zeros((2,)+tuple(shape[-2:]), wcs, dtype) - if separable == "auto": separable = wcsutils.is_cyl(wcs) + if separable == "auto": separable = wcsutils.is_separable(wcs) if separable: # If posmap could return a (dec,ra) tuple instead of an ndmap, # we could have returned np.broadcast_arrays(dec, ra) instead. 
@@ -540,20 +546,20 @@ def contains(shape, wcs, pos, unit="coord"): else: pix = pos return np.all((pix>=0)&(pix.T 1: imap = imap.copy() - if order > 1: - imap = utils.interpol_prefilter(imap, order=order, inplace=True) + ip = utils.interpolator(imap, order=order) if omap is None: # Generate an output map if not nside: @@ -2282,7 +2452,7 @@ def to_healpix(imap, omap=None, nside=0, order=3, chunk=100000, destroy_input=Fa pos = np.array(healpy.pix2ang(nside, np.arange(i, min(npix,i+chunk)))) # Healpix uses polar angle, not dec pos[0] = np.pi/2 - pos[0] - omap[...,i:i+chunk] = imap.at(pos, order=order, mask_nan=False, prefilter=False) + omap[...,i:i+chunk] = imap.at(pos, ip=ip) return omap def to_flipper(imap, omap=None, unpack=True): @@ -2805,7 +2975,7 @@ def ifftshift(map, inplace=False): def fillbad(map, val=0, inplace=False): return np.nan_to_num(map, copy=not inplace, nan=val, posinf=val, neginf=val) -def resample(map, oshape, off=(0,0), method="fft", mode="wrap", corner=False, order=3): +def resample(map, oshape, off=(0,0), method="fft", border="wrap", corner=False, order=3): """Resample the input map such that it covers the same area of the sky with a different number of pixels given by oshape.""" # Construct the output shape and wcs @@ -2819,7 +2989,7 @@ def resample(map, oshape, off=(0,0), method="fft", mode="wrap", corner=False, or off -= 0.5 - 0.5*np.array(oshape[-2:],float)/map.shape[-2:] # in output units opix = pixmap(oshape) - off[:,None,None] ipix = opix * (np.array(map.shape[-2:],float)/oshape[-2:])[:,None,None] - omap = ndmap(map.at(ipix, unit="pix", mode=mode, order=order), owcs) + omap = ndmap(map.at(ipix, unit="pix", border=border, order=order), owcs) else: raise ValueError("Invalid resample method '%s'" % method) return omap diff --git a/pixell/fft.py b/pixell/fft.py index 2bd67f80..88c69ad2 100644 --- a/pixell/fft.py +++ b/pixell/fft.py @@ -138,7 +138,8 @@ def fft(tod, ft=None, nthread=0, axes=[-1], flags=None, _direction="FFTW_FORWARD use in the 
fft. The default (0) uses the value specified by the OMP_NUM_THREAD environment varible if that is specified, or the total number of cores on the computer otherwise.""" - tod = asfcarray(tod) + tod = asfcarray(tod) + axes = utils.astuple(-1 if axes is None else axes) if tod.size == 0: return nt = nthread or nthread_fft if flags is None: flags = default_flags @@ -164,7 +165,8 @@ def ifft(ft, tod=None, nthread=0, normalize=False, axes=[-1],flags=None, engine= meaning that fft followed by ifft will multiply the data by the length of the transform. By specifying the normalize argument, you can turn normalization on, though the normalization step will not use paralellization.""" - ft = asfcarray(ft) + ft = asfcarray(ft) + axes = utils.astuple(-1 if axes is None else axes) if ft.size == 0: return nt = nthread or nthread_ifft if flags is None: flags = default_flags @@ -184,7 +186,8 @@ def ifft(ft, tod=None, nthread=0, normalize=False, axes=[-1],flags=None, engine= def rfft(tod, ft=None, nthread=0, axes=[-1], flags=None, engine="auto"): """Equivalent to fft, except that if ft is not passed, it is allocated with appropriate shape and data type for a real-to-complex transform.""" - tod = asfcarray(tod) + tod = asfcarray(tod) + axes = utils.astuple(-1 if axes is None else axes) if ft is None: oshape = list(tod.shape) oshape[axes[-1]] = oshape[axes[-1]]//2+1 @@ -198,7 +201,8 @@ def irfft(ft, tod=None, n=None, nthread=0, normalize=False, axes=[-1], flags=Non is specified, that is used as the length of the last transform axis of the output array. Otherwise, the length of this axis is computed assuming an even original array.""" - ft = asfcarray(ft) + ft = asfcarray(ft) + axes = utils.astuple(-1 if axes is None else axes) if tod is None: oshape = list(ft.shape) oshape[axes[-1]] = n or (oshape[axes[-1]]-1)*2 @@ -221,8 +225,9 @@ def dct(tod, dt=None, nthread=0, normalize=False, axes=[-1], flags=None, type="D Note that DCTs and DSTs were only added to pyfftw in version 13.0. 
The function will fail with an Invalid scheme error for older versions. """ - tod = asfcarray(tod) - type= _dct_names[type] + tod = asfcarray(tod) + type = _dct_names[type] + axes = utils.astuple(-1 if axes is None else axes) if dt is None: dt = empty(tod.shape, tod.dtype) return fft(tod, dt, nthread=nthread, axes=axes, flags=flags, _direction=[type]*len(axes), engine=engine) @@ -256,6 +261,7 @@ def idct(dt, tod=None, nthread=0, normalize=False, axes=[-1], flags=None, type=" dt = asfcarray(dt) type = _dct_inverses[_dct_names[type]] off = _dct_sizes[type] + axes = utils.astuple(-1 if axes is None else axes) if tod is None: tod = empty(dt.shape, dt.dtype) fft(dt, tod, nthread=nthread, axes=axes, flags=flags, _direction=[type]*len(axes), engine=engine) @@ -341,6 +347,7 @@ def shift(a, shift, axes=None, nofft=False, deriv=None, engine="auto"): ca = a+0j shift = np.atleast_1d(shift) if axes is None: axes = range(-len(shift),0) + axes = utils.astuple(axes) fa = fft(ca, axes=axes, engine=engine) if not nofft else ca for i, ax in enumerate(axes): ax %= ca.ndim @@ -353,6 +360,24 @@ def shift(a, shift, axes=None, nofft=False, deriv=None, engine="auto"): else: ca = fa return ca if np.iscomplexobj(a) else ca.real +def resample(a, n, axes=None, nthread=0, engine="auto"): + """Given an array a, resize the given axes (defaulting to the last ones) to + length n (tuple or int) using Fourier resampling. 
For example, if a has shape + (2,3,4), then resample(a, 10, -1) has shape (2,3,10), and resample(a, (20,10), (0,2)) + has shape (20,3,10).""" + a = np.asarray(a) + n = utils.astuple(n) + if axes is None: + axes = [-len(n)+i for i in range(len(n))] + if len(n) != len(axes): + raise ValueError("Resize size n = %s does not match axes = %s" % (str(n),str(axes))) + fa = fft(a, axes=axes, nthread=nthread, engine=engine) + norm = 1/np.prod([a.shape[ax] for ax in axes]) + fa = resample_fft(fa, n, axes=axes, norm=norm) + out = ifft(fa, axes=axes, normalize=False, nthread=nthread, engine=engine) + if not np.iscomplexobj(a): out = out.real + return out + def resample_fft(fa, n, out=None, axes=-1, norm=1, op=lambda a,b:b): """Given array fa[{dims}] which is the fourier transform of some array a, transform it so that that it corresponds to the fourier transform of @@ -372,9 +397,8 @@ def resample_fft(fa, n, out=None, axes=-1, norm=1, op=lambda a,b:b): fa = np.asanyarray(fa) # Support n and axes being either tuples or a single number, # and broadcast n to match axes - try: axes = tuple(axes) - except TypeError: axes = (axes,) - n = np.zeros(len(axes),int)+n + axes = utils.astuple(axes) + n = np.zeros(len(axes),int)+n # Determine the shape of the output array oshape = list(fa.shape) for i, ax in enumerate(axes): @@ -401,6 +425,183 @@ def transfer(dest, source, norm, op): transfer(out[sel], fa[sel], norm, op) return out +def interpol_nufft(a, inds, out=None, axes=None, normalize=True, + periodicity=None, epsilon=None, nthread=None, nofft=False): + """Given some array a[{pre},{dims}] interpolate it at the given + inds[len(dims),{post}], resulting in an output with shape [{pre},{post}]. + The signal is assumed to be periodic with the size of a unless this is overridden + with the periodicity argument, which should have an integer for each axis being + transformed. Normally the last ndim = len(inds) axes of a are interpolated. + This can be overridden with the axes argument. 
+ + By default the interpolation is properly normalized. This can be turned off + with the normalization argument, in which case the output will be too high + by a factor of np.prod([a.shape[ax] for ax in axes]). If all axes are used, + this simplifies to a.size""" + # This function could be implemented as simply u2nu(fft(a),inds). The problem + # with this is that a full fourier-array needs to be allocated. I can save + # some memory by instead doing the fft per field, at the cost of it being + # a bit hacky + op = None if nofft else lambda a, h: fft.fft(a, nthread=h.nthread, axes=h.axes) + return u2nu(a, inds, out=out, axes=axes, periodicity=periodicity, + epsilon=epsilon, nthread=nthread, normalize=normalize, complex=False, op=op) + +def u2nu(fa, inds, out=None, axes=None, periodicity=None, epsilon=None, nthread=None, + normalize=False, forward=False, complex=True, op=None): + """Given complex fourier coefficients fa[{pre},{dims}] corresponding to + some real-space array a, evaluate the real-space signal at the given + inds[len(dims),{post}], resulting in a output with shape [{pre},{post}]. + + Arguments: + * fa: Array of equi-spaced fourier coefficients. Complex with shape [{pre},{dims}] + * inds: Array of positions at which to evaluate the inverse Fourier transform + of fa. Real with shape [len(dims),{post}] + * out: Array to write result to. Real or complex with shape [{pre},{post}]. + Optional. Allocated if missing. + * axes: Tuple of axes to perform the transform along. len(axes)=len(dims). + Optional. Defaults to the last len(dims) axes. + * periodicity: Periodicity assumed in the Fourier transform. Tuple with length + len(dims). Defaults to the shape of the axes being transformed. + * epsilon: The target relative accuracy of the non-uniform FFT. Defaults + to 1e-5 for single precision and 1e-12 for double precision. See the + ducc0.nufft documentation for details. 
+ * normalize: If True, the output is divided by prod([fa.shape[ax] for ax in axes]), + that is, the total number of elements in the transform. This normalization is + equivalent to that of ifft. Defaults to False. + * forward: Controls the sign of the exponent in the Fourier transform. By default + a backwards transform (fourier to real) is performed. By passing forward=True, + you can instead regard fa as a real-space array and out as a non-equispaced + Fourier array. + * complex: Only relevant if out=None. Controls whether out is allocated as a + real or complex array. Defaults to complex. + """ + h = _nufft_helper(fa, out, inds, axes=axes, nuout=True, periodicity=periodicity, + epsilon=epsilon, nthread=nthread, normalize=normalize, complex=complex) + if op is None: op = lambda fa, h: fa + for uI, nuI in zip(h.uiter, h.nuiter): + grid = op(h.u[uI],h).astype(h.ctype, copy=False) + res = ducc0.nufft.u2nu(grid=grid, coord=h.iflat, forward=forward, + epsilon=h.epsilon, nthreads=h.nthread, periodicity=h.periodicity, + fft_order=True) + if not np.iscomplexobj(h.nu): + res = res.real + h.nu[nuI] = res.reshape(h.inds.shape[1:]) + if h.normalize: + h.nu /= h.norm + return h.nu + +# FIXME: Check normalization +def nu2u(a, inds, out=None, oshape=None, axes=None, periodicity=None, epsilon=None, nthread=None, + normalize=False, forward=False): + h = _nufft_helper(out, a, inds, axes=axes, nuout=False, periodicity=periodicity, ushape=oshape, + epsilon=epsilon, nthread=nthread, normalize=normalize, complex=complex) + work = np.zeros(h.tshape, h.ctype) + for uI, nuI in zip(h.uiter, h.nuiter): + res = ducc0.nufft.nu2u(points=h.nu[nuI], coord=h.iflat, out=work, + forward=forward, epsilon=h.epsilon, nthreads=h.nthread, + periodicity=h.periodicity, fft_order=True) + if not np.iscomplexobj(h.u): + res = res.real + h.u[uI] = res + if h.normalize: + h.u /= h.norm + return h.u + +def iu2nu(a, inds, out=None, oshape=None, axes=None, periodicity=None, epsilon=None, nthread=None, + 
normalize=False, forward=False): + """The inverse of nufft/u2nu. Given non-equispaced samples a[{pre},{post}] and + their coordinates inds[len(dims),{post}], calculates the equispaced + Fourier coefficients out[{pre},{dims}] of a. + + Arguments: + * a: Array of non-equispaced values. Real or complex with shape [{pre},{post}] + * inds: Coordinates of samples in a. Real with shape [len(dims),{post}]. + * out: Equispaced Fourier coefficients of a. Complex with shape [{pre},{dims}]. + Optional, but if missing, the shape of the out array to allocate must be + specified using the oshape argument + * oshape: Tuple giving the shape to use when allocating out (if it's not passed in). + See u2nu for the meaning of the other arguments. + """ + h = _nufft_helper(out, a, inds, axes=axes, nuout=False, ushape=oshape, + periodicity=periodicity, epsilon=epsilon, nthread=nthread, + normalize=normalize, complex=complex) + work = np.zeros(h.tshape, h.ctype) + def wzip(u): return u.reshape(-1).view(h.rtype) + def wunzip(x): return x.view(h.ctype).reshape(h.tshape) + def P(u): return ducc0.nufft.u2nu(grid=u, coord=h.iflat, forward=forward, + epsilon=h.epsilon, nthreads=h.nthread, periodicity=h.periodicity, fft_order=True) + def PT(nu): + return ducc0.nufft.nu2u(points=nu, coord=h.iflat, + out=work, forward=not forward, + epsilon=h.epsilon, nthreads=h.nthread, periodicity=h.periodicity, fft_order=True) + for uI, nuI in zip(h.uiter, h.nuiter): + # Invert u2nu by finding the least-squares solution to + # a = u2nu(out, inds). Written linearly this is a = P out + # with solution out = (P'P)"P'a. The CG solver wants real numbers, though, + # so we hack around that with view + # Set up the equation system. 
Our degrees of freedom are flattened real u + b = wzip(PT(h.nu[nuI].reshape(-1))) + def A(x): return wzip(PT(P(wunzip(x)))) + solver = utils.CG(A, b) + while solver.err > h.epsilon: + solver.step() + res = wunzip(solver.x) + if not np.iscomplexobj(h.u): + res = res.real + h.u[uI] = res + if h.normalize: + h.u *= h.norm + return h.u + +# FIXME: Check normalization +def inu2u(fa, inds, out=None, axes=None, periodicity=None, epsilon=None, nthread=None, + normalize=False, forward=False, complex=True): + h = _nufft_helper(fa, out, inds, axes=axes, nuout=True, + periodicity=periodicity, epsilon=epsilon, nthread=nthread, + normalize=normalize, complex=complex) + work = np.zeros(h.tshape, h.ctype) + def wzip(nu): return nu.view(h.rtype) + def wunzip(x): return x.view(h.ctype) + def P(nu): return ducc0.nufft.nu2u(points=nu, coord=h.iflat, out=work, forward=forward, + epsilon=h.epsilon, nthreads=h.nthread, periodicity=h.periodicity, fft_order=True) + def PT(u): return ducc0.nufft.u2nu(grid=u, coord=h.iflat, forward=not forward, + epsilon=h.epsilon, nthreads=h.nthread, periodicity=h.periodicity, fft_order=True) + for uI, nuI in zip(h.uiter, h.nuiter): + # Invert nu2u by finding the least-squares solution to + # fa = nu2u(out, inds). Written linearly this is fa = P out + # with solution out = (P'P)"P'fa + b = wzip(PT(h.u[uI])) + def A(x): return wzip(PT(P(wunzip(x)))) + solver = utils.CG(A, b) + while solver.err > h.epsilon: + solver.step() + res = wunzip(solver.x) + if not np.iscomplexobj(h.nu): + res = res.real + h.nu[nuI] = res.reshape(h.inds.shape[1:]) + if h.normalize: + h.nu *= h.norm + return h.nu + +# Alternative nufft interface more in line with fft and curvedsky. +# TODO: Add proper docstrings here. Can I avoid lots of repetition? 
+ +def nufft(a, inds, out=None, oshape=None, axes=None, periodicity=None, epsilon=None, nthread=None, normalize=False, flip=False): + return iu2nu(a, inds, out=out, oshape=oshape, axes=axes, periodicity=periodicity, epsilon=epsilon, nthread=nthread, normalize=normalize, forward=flip) + +def inufft(fa, inds, out=None, axes=None, periodicity=None, epsilon=None, nthread=None, normalize=False, flip=False, complex=True, op=None): + return u2nu(fa, inds, out=out, axes=axes, periodicity=periodicity, epsilon=epsilon, nthread=nthread, normalize=normalize, forward=flip, complex=complex, op=op) + +def nufft_adjoint(a, inds, out=None, oshape=None, axes=None, periodicity=None, epsilon=None, nthread=None, normalize=False, flip=False): + return nu2u(a, inds, out=out, oshape=oshape, axes=axes, periodicity=periodicity, epsilon=epsilon, nthread=nthread, normalize=normalize, forward=not flip) + +def inufft_adjoint(fa, inds, out=None, axes=None, periodicity=None, epsilon=None, nthread=None, normalize=False, flip=False, complex=True): + return inu2u(fa, inds, out=out, axes=axes, periodicity=periodicity, epsilon=epsilon, nthread=nthread, normalize=normalize, forward=not flip) + + +########### Helper functions ############## + + def fft_flat(tod, ft, nthread=1, axes=[-1], flags=None, _direction="FFTW_FORWARD"): """Workaround for intel FFTW wrapper. Flattens appropriate dimensions of intput and output arrays to avoid crash that otherwise happens for arrays with @@ -430,3 +631,76 @@ def ifft_flat(ft, tod, nthread=1, axes=[-1], flags=None): tod = utils.partial_expand(tod, shape_tod, axes=axes, pos=0) return tod +def _nufft_helper(u, nu, inds, axes=None, periodicity=None, epsilon=None, + nuout=False, nthread=None, complex=True, normalize=False, ushape=None): + """Do the type checking etc. needed to prepare for our nufft operations. + This is a lot of code, but the overhead is around 300 µs, plus any time + needed to allocate the output array. 
So there's a bit of overhead, but + not anything we can't live with, and the raw ducc interface is available + for when this overhead is too much.""" + from . import bunch + # Prepare arguments for nufft operations. Must ensure that + # * inds → iflat[ndim,npoint] f32 or f64 + # * ctype = c64 or c128 based u or nu, priority to which one is output which must be specified + u = np.asarray(u) if u is not None else None + nu = np.asarray(nu) if nu is not None else None + inds = np.asarray(inds) + # Are we single or double precision? This set of statements sets up a priority + # order for which array to get the dtype from + dtypes = [nu if nuout else u, u if not nuout else nu, inds, np.float64()] + rtypes = [utils.real_dtype(d.dtype) for d in dtypes if d is not None] + rtype = [d for d in rtypes if d in [np.float32, np.float64]][0] + ctype = utils.complex_dtype(rtype) + if ctype not in [np.complex64, np.complex128]: + raise ValueError("only single and double precision supported") + # Convert inds to the right dtype only if it has an invalid dtype + if inds.dtype not in [np.float32,np.float64]: + inds = inds.astype(rtype, copy=False) + ndim = inds.shape[0] + # By default the last ndim dimensions are transformed + if axes is None: axes = tuple(range(-ndim,0)) + axes = utils.astuple(axes) + if len(axes) != ndim: raise ValueError("Number of axes to transform does not match len(inds)!") + # Set up output array. This depends on which direction we're going + odtype = ctype if complex else rtype + if nuout: + npre = u.ndim-ndim + if npre < 0: + raise ValueError("uniform array must has at least as many dimensions as indexed by the first axis of inds!") + pshape = utils.without_inds(u.shape, axes) + if nu is None: + # Output array. 
Allocating it like this lets it inherit any subclass of + # inds, which is useful when interpolating an enmap with another enmap + nu = np.zeros_like(inds, shape=pshape+inds.shape[1:], dtype=odtype) + if nu.shape != pshape+inds.shape[1:]: + raise ValueError("nu must have shape pshape+inds.shape[1:]") + else: + if u is None: + if ushape is None: raise ValueError("Either the output uniformly sampled array or its shape must be provided") + u = np.zeros(ushape, dtype=odtype) + npre = u.ndim-ndim + pshape = utils.without_inds(u.shape, axes) + # Hard to do any more sanity checks here + tshape = utils.only_inds(u.shape, axes) + npoint = np.prod(tshape) + # Periodicity of the full space. Allows us to support arrays that represent + # a subset of a bigger, periodic array + if periodicity is None: periodicity = tshape + else: periodicity = np.zeros(ndim,int)+periodicity + nthread = nthread or nthread_fft + # Target accuracy + if epsilon is None: + epsilon = 1e-5 if ctype == np.complex64 else 1e-12 + # ducc wants just a single pre-dimension for inds, so flatten it. + iflat = inds.reshape(ndim,-1).T + # Do the actual looping + other_axes = tuple(utils.complement_inds(axes, u.ndim)) + axall = tuple(range(ndim)) + norm = np.prod([u.shape[ax] for ax in axes]) + uiter = utils.nditer(u.shape, axes=other_axes) + nuiter = utils.nditer(nu.shape[:npre]) + return bunch.Bunch(u=u, nu=nu, inds=inds, iflat=iflat, + epsilon=epsilon, nthread=nthread, normalize=normalize, norm=norm, + periodicity=periodicity, pshape=pshape, tshape=tshape, npoint=npoint, + other_axes=other_axes, axall=axall, complex=complex, rtype=rtype, + ctype=ctype, npre=npre, uiter=uiter, nuiter=nuiter) diff --git a/pixell/pointsrcs.py b/pixell/pointsrcs.py index f4697fa9..c039927b 100644 --- a/pixell/pointsrcs.py +++ b/pixell/pointsrcs.py @@ -81,7 +81,7 @@ def sim_objects(shape, wcs, poss, amps, profile, prof_ids=None, omap=None, vmin= sources will have been added (or maxed etc. depending on op) into the map. 
Otherwise, the only signal in the map will be the objects.""" dtype = np.float32 # C extension only supports this dtype - if separable == "auto": separable = wcsutils.is_cyl(wcs) + if separable == "auto": separable = wcsutils.is_separable(wcs) # Object positions obj_decs = np.asanyarray(poss[0], dtype=dtype, order="C") obj_ras = np.asanyarray(poss[1], dtype=dtype, order="C") @@ -149,7 +149,7 @@ def radial_sum(map, poss, bins, oprofs=None, separable="auto", Returns the resulting profiles. If oprof was specified, then the same object will be returned (after being updated of course).""" dtype = np.float32 # C extension only supports this dtype - if separable == "auto": separable = wcsutils.is_cyl(map.wcs) + if separable == "auto": separable = wcsutils.is_separable(map.wcs) # Object positions obj_decs = np.asanyarray(poss[0], dtype=dtype, order="C") obj_ras = np.asanyarray(poss[1], dtype=dtype, order="C") @@ -264,7 +264,7 @@ def sim_srcs_python(shape, wcs, srcs, beam, omap=None, dtype=None, nsigma=5, rma The source simulation is sped up by using a source lookup grid. 
""" - if separable == "auto": separable = wcsutils.is_cyl(wcs) + if separable == "auto": separable = wcsutils.is_separable(wcs) if omap is None: omap = enmap.zeros(shape, wcs, dtype) ishape = omap.shape omap = omap.preflat diff --git a/pixell/uharm.py b/pixell/uharm.py index ea90c005..285dcd7e 100644 --- a/pixell/uharm.py +++ b/pixell/uharm.py @@ -138,7 +138,7 @@ def hprof2rprof(self, harm, r): return curvedsky.harm2profile(harm, r) def lprof2hprof(self, lprof): if self.mode == "flat": - return enmap.enmap(utils.interpol(lprof, self.l[None], order=1, mode="constant"), self.wcs, copy=False) + return enmap.enmap(utils.interpol(lprof, self.l[None], order=1, border="constant"), self.wcs, copy=False) else: if lprof.shape[-1] >= self.lmax+1: return lprof[...,:self.lmax+1] diff --git a/pixell/utils.py b/pixell/utils.py index 9386bb7f..fd72f2d5 100644 --- a/pixell/utils.py +++ b/pixell/utils.py @@ -1,4 +1,4 @@ -import numpy as np, scipy.ndimage, os, errno, scipy.optimize, time, datetime, warnings, re, sys +import numpy as np, scipy.ndimage, os, errno, scipy.optimize, time, datetime, warnings, re, sys, scipy.special try: xrange except: xrange = range try: basestring @@ -193,8 +193,9 @@ def inverse_order(order): def complement_inds(inds, n): """Given a subset of range(0,n), return the missing values. E.g. complement_inds([0,2,4],7) => [1,3,5,6]""" + if inds is None: inds = np.arange(n) mask = np.ones(n, bool) - mask[inds] = False + mask[np.array(inds)] = False return np.where(mask)[0] def unmask(arr, mask, axis=0, fill=0): @@ -306,6 +307,12 @@ def deslope(d, w=1, inplace=False, axis=-1, avg=np.mean): di -= np.arange(di.size)*(avg(di[-w:])-avg(di[:w]))/(di.size-1)+avg(di[:w]) return d +def argmax(arr): + """Multidimensional argmax. 
Returns a tuple indexing the full array + instead of just a number indexing the flattened array like np.argmax does""" + arr = np.asanyarray(arr) + return np.unravel_index(np.argmax(arr), arr.shape) + def ctime2mjd(ctime): """Converts from unix time to modified julian date.""" return np.asarray(ctime)/86400. + 40587.0 @@ -515,44 +522,129 @@ def dedup(a): The original is not modified.""" return a[np.concatenate([[True],a[1:]!=a[:-1]])] -def interpol(a, inds, order=3, mode="nearest", mask_nan=False, cval=0.0, prefilter=True): - """Given an array a[{x},{y}] and a list of float indices into a, - inds[len(y),{z}], returns interpolated values at these positions as [{x},{z}].""" - a = np.asanyarray(a) +def interpol(arr, inds, out=None, mode="spline", border="nearest", + order=3, cval=0.0, epsilon=None, ip=None): + """Given an array arr[{x},{y}] and a list of float indices into a, + inds[len(y),{z}], returns interpolated values at these positions as [{x},{z}]. + + The mode and order arguments control the interpolation type. These can be: + * mode=="nn" or (mode=="spline" and order==0): Nearest neighbor interpolation + * mode=="lin" or (mode=="spline" and order==1): Linear interpolation + * mode=="cub" or (mode=="spline" and order==3): Cubic interpolation + * mode=="fourier": Non-uniform fourier interpolation + + The border argument controls the boundary condition. This does not apply + for fourier interpolation, which always assumes periodic boundary. + Valid values are: + * "nearest": Indices outside the array use the value from the nearest + point on the edge. + * "cyclic": Periodic boundary conditions + * "mirrored": Mirrored boundary conditions + * "constant": Use a constant value, given by the cval argument + + Epsilon controls the target relative accuracy of the interpolation. + Only applies to fourier interpolation. 
Spline interpolation is + overall much less accurate (assuming a band-limited true signal), + and its accuracy can't be controlled, but roughly corresponds to 1e-3. + Defaults to 1e-5 for single precision and 1e-12 for double precision + arrays. + + Compatibility notes: + * mask_nan is no longer supported. You must implement this yourself + if you need it. Do something like + mask = ~np.isfinite(arr) + out = interpol(arr, inds, ...) + omask= interpol(mask, inds, mode="nn") + out[omask!=0] = np.nan + * prefilter is no longer supported. This argument let the interpolation + skip a heavy prefiltering step if the array was already filtered. + This was useful, but assumed that the precomputed array was the same + shape and data type as the array to be interpolated, which is not the + case for fourier interpolation. This functionality was replaced by + interpolator objects returned by utils.interpolator, which are what's + used to implement this function. + """ + arr = np.asanyarray(arr) inds = np.asanyarray(inds) - inds_orig_nd = inds.ndim - if inds.ndim == 1: inds = inds[:,None] - - npre = a.ndim - inds.shape[0] - res = np.empty(a.shape[:npre]+inds.shape[1:],dtype=a.dtype) - fa, fr = partial_flatten(a, range(npre,a.ndim)), partial_flatten(res, range(npre, res.ndim)) - if mask_nan: - mask = ~np.isfinite(fa) - fa[mask] = 0 - for i in range(fa.shape[0]): - fr[i].real = scipy.ndimage.map_coordinates(fa[i].real, inds, order=order, mode=mode, cval=cval, prefilter=prefilter) - if np.iscomplexobj(fa[i]): - fr[i].imag = scipy.ndimage.map_coordinates(fa[i].imag, inds, order=order, mode=mode, cval=cval, prefilter=prefilter) - if mask_nan and np.sum(mask) > 0: - fmask = np.empty(fr.shape,dtype=bool) - for i in range(mask.shape[0]): - fmask[i] = scipy.ndimage.map_coordinates(mask[i], inds, order=0, mode=mode, cval=cval, prefilter=prefilter) - fr[fmask] = np.nan - if inds_orig_nd == 1: res = res[...,0] - return res - -def interpol_prefilter(a, npre=None, order=3, inplace=False, 
mode="nearest"): - if order < 2: return a - a = np.asanyarray(a) - if not inplace: a = a.copy() - if npre is None: npre = max(0,a.ndim - 2) - if npre < 0: npre = a.ndim-npre - # spline_filter was looping through the enmap pixel by pixel with getitem. - # Not using flatview got around it, but I don't understand why it happend - # in the first place. - for I in nditer(a.shape[:npre]): - a[I] = scipy.ndimage.spline_filter(a[I], order=order, mode=mode) - return a + npre = arr.ndim - len(inds) + if ip is None: + ip = interpolator(arr, npre, mode=mode, border=border, order=order, + cval=cval, epsilon=epsilon) + return ip(inds, out=out) + +def interpolator(arr, npre=0, mode="spline", border="nearest", order=3, cval=0.0, + epsilon=None): + """Construct an interpolator object that can be used to quickly interpolate + many positions in some array arr. Wrapper for the underlying SplineInterpolator + and FourierInterpolator classes. Used to implement the interpolate function. + See it for argument details.""" + mode, order = _ip_get_mode(mode, order) + if mode == "spline": + return SplineInterpolator(arr, npre=npre, mode=mode, border=border, + order=order, cval=cval) + elif mode == "fourier": + return FourierInterpolator(arr, npre=npre, epsilon=epsilon) + else: + raise ValueError("Unrecognized interpolation mode '%s'" % str(mode)) + +class SplineInterpolator: + prefiltered = True + def __init__(self, arr, npre=0, mode="spline", border="nearest", order=3, cval=0.0): + self.mode, self.order = _ip_get_mode(mode, order) + self.npre = npre % arr.ndim + self.cval = cval + self.border = border + if self.mode != "spline": raise ValueError("Unrecognized spline interpolation mode '%s'" % str(mode)) + arr = np.asanyarray(arr) + if self.order > 1: + arr = arr.copy() + for I in nditer(arr.shape[:npre]): + arr[I] = scipy.ndimage.spline_filter(arr[I], order=self.order, mode=self.border) + self.arr = arr + def __call__(self, inds, out=None): + inds, out = _ip_prepare(self, inds, out=out) + # 
Do the actual interpolation + for I in nditer(self.arr.shape[:self.npre]): + out[I] = scipy.ndimage.map_coordinates(self.arr[I], inds, order=self.order, + mode=self.border, cval=self.cval, prefilter=False) + return out + +class FourierInterpolator: + prefiltered = False + def __init__(self, arr, npre=0, epsilon=None): + from . import fft + self.npre = npre % arr.ndim + self.arr = np.asanyarray(arr) + self.epsilon = epsilon + self.farr = fft.fft(arr, axes=tuple(range(self.npre,arr.ndim))) + def __call__(self, inds, out=None): + from . import fft + inds, out = _ip_prepare(self, inds, out=out) + out = fft.interpol_nufft(self.farr, inds, out=out, nofft=True, epsilon=self.epsilon) + return out + +def _ip_get_mode(mode, order): + # The type of interpolation to do + if mode in ["nn", "nearest"]: mode, order = "spline", 0 + elif mode in ["lin","linear" ]: mode, order = "spline", 1 + elif mode in ["cub","cubic" ]: mode, order = "spline", 3 + elif mode in ["fft","nufft","fourier"]: mode = "fourier" + if mode not in ["spline", "fourier"]: raise ValueError("Unrecognized interpol mode '%s'" % str(mode)) + return mode, order + +def _ip_prepare(self, inds, out=None): + inds = np.asanyarray(inds) + ndim = inds.ndim + if self.arr.ndim-len(inds) != self.npre: + raise ValueError("arr.ndim-len(inds) != npre") + # Allow us to use ndim<2 inputs, e.g. 
interpol(np.arange(6),3) instead of + # interpol(np.arange(6),[[3]]) + while inds.ndim < 2: inds = inds[...,None] + if out is None: + # Doing it this way lets interpol inherit the array subclass from inds, which + # is useful when interpolating one enmap with another enmap + out = np.zeros_like(inds, shape=self.arr.shape[:self.npre]+inds.shape[1:], dtype=self.arr.dtype) + return inds, out def interp(x, xp, fp, left=None, right=None, period=None): """Unlike utils.interpol, this is a simple wrapper around np.interp that extends it @@ -1275,7 +1367,7 @@ def allgather(a, comm): rather than needing an output argument.""" a = np.asarray(a) res = np.zeros((comm.size,)+a.shape,dtype=a.dtype) - if np.issubdtype(a.dtype, np.string_): + if np.issubdtype(a.dtype, np.bytes_): comm.Allgather(a.view(dtype=np.uint8), res.view(dtype=np.uint8)) else: comm.Allgather(a, res) @@ -1310,10 +1402,11 @@ def allgatherv(a, comm, axis=0): #print(comm.rank, "fa.shape", fa.shape) ra = fa.reshape(fa.shape[0],-1) if fa.size > 0 else fa.reshape(0,np.prod(fa.shape[1:],dtype=int)) N = ra.shape[1] - n = allgather([len(ra)],comm) + # Number of elements each task has + n = allgather([len(ra)],comm).reshape(-1) o = cumsum(n) rb = np.zeros((np.sum(n),N),dtype=ra.dtype) - #print("A", comm.rank, ra.shape, ra.dtype, rb.shape, rb.dtype, n, N) + # print("A", comm.rank, ra.shape, ra.dtype, rb.shape, rb.dtype, n, N) comm.Allgatherv(ra, (rb, (n*N,o*N))) fb = rb.reshape((rb.shape[0],)+fa.shape[1:]) # Restore original data type @@ -2646,15 +2739,15 @@ def uvec(n, i, dtype=np.float64): u[i] = 1 return u -def ubash(Afun, n, dtype=np.float64): +def ubash(Afun, n, idtype=np.float64, odtype=None): """Find the matrix representation Amat of linear operator Afun by repeatedly applying it unit vectors with length n.""" - v = Afun(uvec(n,0,dtype=dtype)) + v = Afun(uvec(n,0,dtype=idtype)) m = len(v) - Amat = np.zeros((m,n), dtype=dtype) + Amat = np.zeros((m,n), dtype=odtype or v.dtype) Amat[:,0] = v for i in range(1,n): - 
Amat[:,i] = Afun(uvec(n,i,dtype=dtype)) + Amat[:,i] = Afun(uvec(n,i,dtype=idtype)) return Amat def load_ascii_table(fname, desc, sep=None, dsep=None): @@ -2891,6 +2984,72 @@ def profile_to_tform_hankel(profile_fun, lmin=0.1, lmax=1e7, n=512, pad=256): lprof = rht.real2harm(profile_fun) return rht.unpad(rht.l, lprof) +class FFTLog: + def __init__(self, xrange=None, krange=None, n=512, pad=0, bias=0): + """Set up an FFTLog, a Fast Fourier Transform for log-spaced data. + Implemented using the Fast Hankel Transform in scipy.fft.fht. + Define the domain by passing in either xrange=[xmin,xmax] or krange=[kmin,kmax], + but not both. The other will be defined as the inverse of the one given. + + The number of sample points is given by n. + + If pad is given, then the domain will be expanded with this number of + points on both sides. These can later be chopped off with the unpad + method. + + bias affects the implied boundary conditions. The standard FFTLog + has bias=0, the default, but a different bias can allow exact results + for power laws. See https://jila.colorado.edu/~ajsh/FFTLog""" + if xrange is None and krange is None: raise ValueError("Either xrange xor krange must be given") + if xrange is not None and krange is not None: raise ValueError("Either xrange xor krange must be given") + if xrange is None: xrange = krange[::-1] + self.step = (np.log(xrange[1])-np.log(xrange[0]))/(n-1) + self.pad = pad + self.n = n + # Define our positions + self.x = np.exp(np.linspace(np.log(xrange[0])-self.step*pad, np.log(xrange[1])+self.step*pad, n+2*pad)) + self.k = 1/self.x[::-1] + self.xh = self.x**(0.5-bias) + self.kh = self.k**(0.5+bias) + # Pre-multiply the normalization into kh. 
This takes care of all + # the normalization except for a factor 2 in the inverse + self.kh /= (np.pi/2)**0.5 + self.bias = bias + def fft(self, a): + """Perform a forward fft along the last axis of a, which must be sampled + at the points self.x""" + import scipy.fft + # Allow us to pass a function to evaluate at the given coordinates + try: a = a(self.x) + except TypeError: pass + xa = a*self.xh + cos = scipy.fft.fht(xa, self.step, -0.5, bias=self.bias)/self.kh + sin = scipy.fft.fht(xa, self.step, +0.5, bias=self.bias)/self.kh + del xa + # Minus sign comes from the negative exponent in the forward fft + return cos-1j*sin + def ifft(self, fa): + """Perform an inverse fft along the last axis of a, which must be sampled + at the points self.k""" + import scipy.fft + # Allow us to pass a function to evaluate at the given coordinates + try: fa = fa(self.k) + except TypeError: pass + kfa = fa*(self.kh/2) + a = scipy.fft.ifht( kfa.real, self.step, -0.5, bias=self.bias)/self.xh + a += scipy.fft.ifht(-kfa.imag, self.step, +0.5, bias=self.bias)/self.xh + return a + def unpad(self, *arrs): + """Remove the padding from arrays used by this object. The + values in the padded areas of the output of the transform have + unreliable values, but they're not cropped automatically to + allow for round-trip transforms. 
Example: + r = unpad(r_padded) + r, l, vals = unpad(r_padded, l_padded, vals_padded)""" + if self.pad == 0: res = arrs + else: res = tuple([arr[...,self.pad:arr.shape[-1]-self.pad] for arr in arrs]) + return res[0] if len(arrs) == 1 else res + def fix_dtype_mpi4py(dtype): """Work around mpi4py bug, where it refuses to accept dtypes with endian info""" return np.dtype(np.dtype(dtype).char) @@ -3091,6 +3250,10 @@ def ascomplex(arr): arr = np.asanyarray(arr) return arr.astype(complex_dtype(arr.dtype)) +def astuple(num_or_list): + try: return tuple(num_or_list) + except TypeError: return (num_or_list,) + # Conjugate gradients def default_M(x): return np.copy(x) @@ -3105,7 +3268,7 @@ class CG: degrees of freedom, where each mpi task only stores parts of the full solution. It is also reentrant, meaning that one can do nested CG if necessary. """ - def __init__(self, A, b, x0=None, M=default_M, dot=default_dot): + def __init__(self, A, b, x0=None, M=default_M, dot=default_dot, destroy_b=False): """Initialize a solver for the system Ax=b, with a starting guess of x0 (0 if not provided). Vectors b and x0 must provide addition and multiplication, as well as the .copy() method, such as provided by numpy arrays. The @@ -3118,8 +3281,8 @@ def __init__(self, A, b, x0=None, M=default_M, dot=default_dot): self.M = M self.dot = dot if x0 is None: - self.x = b*0 - self.r = b + self.x = np.zeros_like(b) + self.r = b.copy() if not destroy_b else b else: self.x = x0.copy() self.r = b-self.A(self.x) @@ -3215,18 +3378,41 @@ def step(self): # Estimate of variance of Ax-b self.abserr = self.rz/len(self.x) -def nditer(shape): +def nditer(shape, axes=None): + """Iterate over all multidimensional indices into an array with the given shape. + If axes is specified, then it should be a list of the axes in shape to iterate + over. The remaining axes will not be indexed (the yielded multi-index will have + slice(None) for those axes). 
The order the entries in axes does not matter.""" ndim = len(shape) - I = [0]*ndim + axes = tuple(range(ndim)) if axes is None else tuple(sorted([ax%ndim for ax in axes])) + axes = axes[::-1] # will iterate backwards below + I = [slice(None)]*ndim + for ax in axes: I[ax] = 0 while True: yield tuple(I) - for dim in range(ndim-1,-1,-1): - I[dim] += 1 - if I[dim] < shape[dim]: break - I[dim] = 0 + for ax in axes: + I[ax] += 1 + if I[ax] < shape[ax]: break + I[ax] = 0 else: break +def without_inds(a, inds): + """Return a as a tuple with the given inds removed. Not optimized for + long arrays""" + if inds is None: return a + inds = astuple(inds) + # Negative inds + inds = [(n+len(a) if n<0 else n) for n in inds] + return tuple([v for i,v in enumerate(a) if i not in inds]) + +def only_inds(a, inds): + """Return a as a tuple with only the given inds present. Not optimized for + long arrays""" + if inds is None: return () + inds = astuple(inds) + return tuple([a[i] for i in inds]) + def first_importable(*args): """Given a list of module names, return the name of the first one that can be imported.""" @@ -3279,7 +3465,7 @@ def primes(n): i += 1 else: n //= i - factors.append(i) + factors.append(i) if n > 1: factors.append(n) return factors @@ -3327,6 +3513,24 @@ def setenv(name, value, keep=False): elif name in os.environ and value is None: del os.environ[name] elif value is not None: os.environ[name] = str(value) +def getaddr(a): + """Get the address of the start of a""" + return a.__array_interface__["data"][0] + +def iscontig(a, naxes=None): + """Return whether array a is C-contiguous. 
If naxes is specified, + then only the last naxes axes need to be contiguous, and axes + before that are ignored.""" + if naxes is None: naxes = a.ndim + naxes = min(a.ndim, naxes) + expected = a.itemsize + for i in range(naxes): + j = a.ndim-1-i + if a.strides[j] != expected: + return False + expected *= a.shape[j] + return True + def zip2(*args): """Variant of python's zip that calls next() the same number of times on all arguments. This means that it doesn't give up immediately after getting @@ -3372,3 +3576,107 @@ def distpow(dist, N): dist = np.convolve(dist,dist) N >>= 1 return res + +def airy(x): + """Dimensionless real-space representation of Airy beam, normalized to peak at 1. + To get the airy beam an angular distance r from the center for a telescope with + aperture diameter D at wavelength λ, use airy(sin(r)*(2*pi*D/λ)). + """ + # Avoid division by zero at low radius + with nowarn(): + return np.where(x<1e-6, 1-x**2, (scipy.special.j1(2*x)/x)**2) + +def lairy(x): + """This is the harmonic space representation of an Airy beam. + To get the airy beam at multipole l for a telescope with aperture + diameter D at wavelength λ, call lairy(l/(2*pi*D/λ)). Valid as long as + the beam is small compared to the curvature of the sky. + + Multiply the result by airy_area(D,λ) if you want the harmonic space representation + of an Airy beam with a real-space peak of one. + """ + x = np.clip(x,0,1) + return (np.arccos(x)-x*(1-x**2)**0.5)/(np.pi/2) + +def airy_area(D, λ): + """Area (steradians) of airy beam for an aperture of size D and wavelength λ. 
+ This is simply (2λ/D)²/π""" + return (2*λ/D)**2/np.pi + +def disk_overlap(d, R): + """Area of overlap between two disks with radius R and distance d between + their centers.""" + x = np.clip(d/(2*R),0,1) + return (np.arccos(x)-x*(1-x**2)**0.5)*(2*R**2) + +def disk_overlap_curved(d, R, tol_flat=1e-4, tol_tiny=1e-10): + """Solid angle of overlap between two disks with radius R and distance d + between their centers, on the sphere. I thought this would be useful for + calculating the curved-sky equivalent for the airy beam, but it seems it + won't. Oh well, it was hard to calculate, so here it is anyway. + + The actual curved-sky airy beam would start from + + airy(r) = int_-R^R dx √(R²-x²) exp(2πiux/λ) + + where u = cos(θ) and θ is the angle from the center of the beam. + This should hold up to an angle of π/2 away from the center. After + that the aperture is mostly obscured, and a new expression will + be needed, if it's not zero. + + I think this is actually what I've implemented in airy(x) above, as + long as one uses sin when calling it. + """ + d, R = np.broadcast_arrays(d, R) + null = (d >= 2*R)|(R==0) + flat = (R < tol_flat) & ~null + tiny = (d < tol_tiny) & ~null + main = ~flat & ~tiny & ~null + res = np.zeros_like(d) + res[flat] = disk_overlap(d[flat],R[flat]) + res[tiny] = _disk_overlap_curved_tiny(d[tiny],R[tiny]) + res[main] = _disk_overlap_curved_main(d[main],R[main]) + return res + +def _disk_overlap_curved_main(d, R): + sinR, cosR = np.sin(R), np.cos(R) + return 2*np.arccos((1-np.cos(d))/sinR**2-1)-4*cosR*np.arccos(cosR/sinR*np.tan(d/2)) + +def _disk_overlap_curved_tiny(d, R): + """Curved sky disk overlap in limit of tiny separations. + First order accuracy in d""" + return 2*np.pi*(1-np.cos(R)) - 4*np.sin(R)*np.sin(d/2) + +# Hm, the first bin can be exceptional, so use the 2nd instead. +# Would be easier if we could # demand that dl_2 = dl_3. Would +# then remove b_2 from the equation system, and say +# b_2 = b_3-(b_4-b_3) = 2b_3-b_4. 
This would give +# l_1 = 0.5*(b_1+b_2) = 0.5*(b_1+2b_3-b_4) +# l_2 = 0.5*(b_2+b_3) = 0.5*(3b_3-b_4) +# +# l = 0.5*[1 2 -1 0 ...] +# [0 3 -1 0 ...] +# [0 1 1 0 ...] +# [............] +def infer_bin_edges(l): + """Given bin centers l[n], returns bin edges b[n+1] such + that l = 0.5*(b[1:]+b[:-1]) under the assumption + b[2] = (b[1]+b[3])/2. This is equivalent to assuming that + the 2nd and 3rd bins have the same size. The problem is + underspecified, so an assumption like this is needed, but + it could be generalized exactly what it is. + """ + from scipy import sparse + n = len(l) + P = 0.5*sparse.csr_array( + ( + np.concatenate([[1,2,-1,3,-1],np.ones(2*(n-2))]), + ( + np.concatenate([[0,0,0,1,1],np.arange(2,n),np.arange(2,n)]), + np.concatenate([[0,1,2,1,2],np.arange(1,n-1),np.arange(2,n)]), + ) + ), shape=(n,n) + ) + b = sparse.linalg.spsolve(P.T.dot(P), P.T.dot(l)) + b = np.concatenate([b[:1],[2*b[1]-b[2]],b[1:]]) + return b diff --git a/pixell/wcsutils.py b/pixell/wcsutils.py index e8aece42..5ee3f80f 100644 --- a/pixell/wcsutils.py +++ b/pixell/wcsutils.py @@ -19,36 +19,81 @@ except: basestring = str def streq(x, s): return isinstance(x, basestring) and x == s -# The origin argument used in the wcs pix<->world routines seems to -# have to be 1 rather than the 0 one would expect. For example, -# if wcs is CAR(crval=(0,0),crpix=(0,0),cdelt=(1,1)), then -# pix2world(0,0,1) is (0,0) while pix2world(0,0,0) is (-1,-1). +# Geometry construction redesign # -# No! the problem is that everythin in the fits header counts from 1, -# so the default crpix should be (1,1), not (0,0). 
With -# CAR(crval(0,0),crpix(1,1),cdelt(1,1)) we get -# pix2world(1,1,1) = (0,0) and pix2world(0,0,0) = (0,0) - -# Useful stuff to be able to do: -# * Create a wcs from (point,res) -# * Create a wcs from (box,res) -# * Create a wcs from (box,shape) -# * Create a wcs from (point,res,shape) -# Can support this by taking arguments: -# pos: point[2] or box[2,2], mandatory -# res: num or [2], optional -# shape: [2], optional -# In cases where shape is not specified, the implied -# shape can be recovered from the wcs and a box by computing -# the pixel coordinates of the corners. So we don't need to return -# it. - -# 1. Construct wcs from box, res (and return shape?) -# 2. Construct wcs from box, shape -# 3. Construct wcs from point, res (this is the most primitive version) - -deg2rad = np.pi/180 -rad2deg = 1/deg2rad +# The old approach was build around the reference point. The idea was that +# this point would always be a pixel center, no matter which coutout of +# the sky one was looking at. The problem with this approach is that it doesn't +# generalize to downgrading, and it clashes with finer detalies such as +# distinguishing between CC and Fejer1, which care about the pixel alignment +# at the poles, not the equator where the reference point usually is. +# +# The new approach will proceed in three steps: +# 1. Specify the projection (ctype, crval) without any pixel details +# 2. Turn it into a full-sky pixelization (crpix, cdelt). This could +# be done by specifying ny,nx or the resolution. We could here +# issue a warning or exception if the sky isn't evenly tiled. +# This part would care about sub-specifiers like :cc or :fejer1 +# 3. Crop this to cover the target area +# +# These can all be handled in separate functions. The output from step +# 1 would be a wcs with default crpix and cdelt values. +# +# Problems: +# 1. 
Currently pixelization can't handle all of these: +# * Fix left side but allow right side to float +# * Fix right side but allow left side to float +# * Fix both sides, but allow total width to float +# Right now the first two are supported, but not the last. +# For example, for CEA we can't expect a reasonable resolution +# to reach the poles with a senible pixel offset, but we want to +# at least make things symmetric around the equator. There's a choice +# between trying to get a pixel edge as close to the poles as possible +# or trying to get a pixel center as close to the poles as possible. +# These could be written as hh adjust and 00 adjust, but hard to fit +# this in currently. +# 2. Some projections have extra parameters, like CEA with lambda and +# ZEA where one might want to be locally conformal. The approach where +# one first builds the fullsky geometry and only later worried about +# restricting it to a part of the sky clashes with this. Can support it, +# but is it a good idea to do it automatically? + +def projection(system, crval=None): + """Generate a pixelization-agnostic wcs""" + system = system.lower() + if crval is None: crval = default_crval(system) + else: crval = np.zeros(2)+crval + if system in ["", "plain"]: return explicit(crval=crval) + return explicit(ctype=["RA---"+system.upper(), "DEC--"+system.upper()], crval=crval) + +def pixelization(pwcs, shape=None, res=None, variant=None): + """Add pixel information to a wcs, returning a full-sky geometry, + or as close to that as the projection allows.""" + # This is the hard part. Many projections have invalid areas, and + # some have infinite size. May just have to handle the cases one by + # one instead of trying to be general + system = get_proj(pwcs) + extent = default_extent(system) + variant = variant or default_variant(system) + offs = parse_variant(variant) + periodic = is_periodic(system) + # We will now split our extent into pixels. 
Find the intermediate + # coordinates of the first and last pixel center along each axis + if shape is None: + res = expand_res(res) + ra1, ra2, nx = pixelize_1d(extent[0], res=res[0], offs=offs[0], periodic=periodic[0]) + dec1,dec2,ny = pixelize_1d(extent[1], res=res[1], offs=offs[1], periodic=periodic[1]) + elif res is None: + ra1, ra2, nx = pixelize_1d(extent[0], n=shape[-2], offs=offs[0], periodic=periodic[0]) + dec1,dec2,ny = pixelize_1d(extent[1], n=shape[-2], offs=offs[1], periodic=periodic[0]) + else: + raise ValueError("Either res or shape must be given to build a pixelization") + # Now that we have the intermediate coordinates of our endpoints, we + # can calculate cdelt and crpix + owcs = pwcs.deepcopy() + owcs.wcs.cdelt = [(ra2-ra1)/(nx-1), (dec2-dec1)/(ny-1)] + owcs.wcs.crpix = [1+(pwcs.wcs.crval[0]-ra1)/owcs.wcs.cdelt[0],1+(pwcs.wcs.crval[1]-dec1)/owcs.wcs.cdelt[1]] + return (ny,nx), owcs def explicit(naxis=2, **args): wcs = WCS(naxis=naxis) @@ -56,6 +101,20 @@ def explicit(naxis=2, **args): setattr(wcs.wcs, key, args[key]) return wcs +def expand_res(res, signs=None, flip=False): + """If res is not None, expand it to length 2. If it wasn't already + length 2, the RA sign will be inverted. 
If flip is True, the res order + will be flipped before expanding""" + if res is None: return res + # Bleh, compensate for later flip + if signs is None: signs = [1,-1] if flip else [-1,1] + res = np.atleast_1d(res) + assert res.ndim == 1, "Invalid res shape" + assert len(res) <= 2, "Invalid res length" + if flip: res, signs = res[::-1], signs[::-1] + if res.size == 1: res = np.array(signs)*res[0] + return res + def describe(wcs): """Since astropy.wcs.WCS objects do not have a useful str implementation, this function provides a relpacement.""" @@ -104,7 +163,10 @@ def is_plain(wcs): def is_cyl(wcs): """Returns True if the wcs represents a cylindrical coordinate system""" - return get_proj(wcs) in ["cyp","cea","car","mer"] and wcs.wcs.crval[1] == 0 + return get_proj(wcs) in ["cyp","cea","car","mer"] + +def is_separable(wcs): + return is_cyl(wcs) and wcs.wcs.crval[1] == 0 def get_proj(wcs): if isinstance(wcs, str): return wcs @@ -112,6 +174,11 @@ def get_proj(wcs): toks = wcs.wcs.ctype[0].split("-") return toks[-1].lower() if len(toks) >= 2 else "" +def parse_system(system, variant=None): + toks = system.split(":") + if len(toks) > 1: return toks[0].lower(), toks[1] + else: return toks[0].lower(), variant + def scale(wcs, scale=1, rowmajor=False, corner=False): """Scales the linear pixel density of a wcs by the given factor, which can be specified per axis. 
This is the same as dividing the pixel size by the same number.""" @@ -126,18 +193,198 @@ def scale(wcs, scale=1, rowmajor=False, corner=False): wcs.wcs.crpix += 0.5 return wcs -def expand_res(res, default_dirs=[1,-1]): - res = np.atleast_1d(res) - assert res.ndim == 1, "Invalid res shape" - if res.size == 1: - return np.array(default_dirs)*res +#def expand_res(res, default_dirs=[1,-1]): +# res = np.atleast_1d(res) +# assert res.ndim == 1, "Invalid res shape" +# if res.size == 1: +# return np.array(default_dirs)*res +# else: +# return res + +########################### +#### Helper functions ##### +########################### + +def is_azimuthal(system): return system.lower() in ["arc", "zea", "sin", "tan", "azp", "slp", "stg", "zpn", "air"] + +def default_crval(system): + if is_azimuthal(system): return [0,90] + else: return [0,0] + +def default_extent(system): + """Return the horizontal and vertical extent of the full sky in degrees. + For some systems the full sky is not representable, in which case a + reasonable compromise is returned""" + system = system.lower() + if system in ["", "plain"]: return [1,1] + # Cylindrical + if system == "car": return [360,180] + elif system == "cea": return [360,360/np.pi] + elif system == "mer": return [360,360] # traditional dec range gives square map + # Zenithal + elif system == "arc": return [360,360] + elif system == "zea": return [720/np.pi,720/np.pi] + elif system == "sin": return [360/np.pi,360/np.pi] # only orthographic supported + elif system == "tan": return [360,360] # goes down to 0.158° above the horizon + # Pseudo-cyl + elif system == "mol": return [720*2**0.5/np.pi,360*2**0.5/np.pi] + elif system == "ait": return [720*2**0.5/np.pi,360*2**0.5/np.pi] + else: raise ValueError("Unsupported system '%s'" % str(system)) + +def default_variant(system): + system = system.lower() + return "fejer1" if system in ["car","plain",""] else "any" + +def extent2bounds(extent): return [[-e/h,e/h] for e in extent] + +def 
is_periodic(system): + system = system.lower() + if is_azimuthal(system) or system in ["", "plain"]: + return [False,False] + else: + return [True,False] + +def parse_variant(name): + name = name.lower() + if name == "safe": rule = "hh,hh" # fully-downgrade safe. What fejer1 should have been + elif name == "fejer1": rule = "00,hh" # stays SHTable after downgrade, but not pix-comp with raw @ that res + elif name == "cc": rule = "00,00" # what we used for pre-DR6. Cannot SHT after downgrade + elif name == "any": rule = "**,**" + else: rule = name + toks = rule.split(",") + if len(toks) != 2 or len(toks[0]) != 2 or len(toks[1]) != 2: + raise ValueError("Could not recognize pixelization variant '%s'" % (str(name))) + left = {"0": 0, "h": 0.5, "*": None} + right = {"0": 0, "h":-0.5, "*": None} + try: + return [[left[tok[0]],right[tok[1]]] for tok in toks] + except KeyError: + raise ValueError("Invalid character in rule '%s'" % str(rule)) + +class PixelizationError(Exception): pass + +def pixelize_1d(w, n=None, res=None, offs=None, periodic=False, adjust=False, sign=1, tol=1e-6, eps=1e-6): + """Figure out how to align pixels along an interval w long such + that there are either n pixels or the resolution is res, and with + the given pixel offsets from the edges. Returns the coordinates of + the center of the first and last pixel.""" + # FIXME: This is a bit poorly thought out. The concept of being + # able to adjust the range and that of having wildcard edges should + # be separate, but the way we've done things now there's no room in + # parse_variant to say something like "0h" but adjustable. For now + # I just have to hardcode that "**" means "00" but adjustable. 
+ o1, o2 = offs if offs is not None else (None, None) + if res is not None: + if res < 0: res, sign = -res, -sign + if o1 is None and o2 is None: + o1 = o2 = 0 + adjust = True + if o2 is None: + # Add a tiny number to avoid having a rounding discontinuity for common values + # off w, res and o1 + n = int(w/res+1-o1+eps) + elif o1 is None: + n = int(w/res+1+o2+eps) + else: + # Both given! Can we satisfy requirement? + nf = w/res+1-(o1-o2) + n = int(nf+eps) + if adjust: + # We're free to redefine w so things work + w = (n+(o1+o2)-1)*res + else: + # Complain if the resolution and offsets are incompatible + if not np.abs(n-nf)world routines seems to +# have to be 1 rather than the 0 one would expect. For example, +# if wcs is CAR(crval=(0,0),crpix=(0,0),cdelt=(1,1)), then +# pix2world(0,0,1) is (0,0) while pix2world(0,0,0) is (-1,-1). +# +# No! the problem is that everythin in the fits header counts from 1, +# so the default crpix should be (1,1), not (0,0). With +# CAR(crval(0,0),crpix(1,1),cdelt(1,1)) we get +# pix2world(1,1,1) = (0,0) and pix2world(0,0,0) = (0,0) + +# Useful stuff to be able to do: +# * Create a wcs from (point,res) +# * Create a wcs from (box,res) +# * Create a wcs from (box,shape) +# * Create a wcs from (point,res,shape) +# Can support this by taking arguments: +# pos: point[2] or box[2,2], mandatory +# res: num or [2], optional +# shape: [2], optional +# In cases where shape is not specified, the implied +# shape can be recovered from the wcs and a box by computing +# the pixel coordinates of the corners. So we don't need to return +# it. + +# 1. Construct wcs from box, res (and return shape?) +# 2. Construct wcs from box, shape +# 3. Construct wcs from point, res (this is the most primitive version) # I need to update this to work better with full-sky stuff. # Should be easy to construct something that's part of a # clenshaw-curtis or fejer sky. 
+deg2rad = np.pi/180 +rad2deg = 1/deg2rad + def plain(pos, res=None, shape=None, rowmajor=False, ref=None): """Set up a plain coordinate system (non-cyclical)""" pos, res, shape, mid = validate(pos, res, shape, rowmajor, default_dirs=[1,1]) @@ -297,32 +544,3 @@ def _apply_zenithal_ref(w, ref): def angdist(lon1,lat1,lon2,lat2): return np.arccos(np.cos(lat1)*np.cos(lat2)*(np.cos(lon1)*np.cos(lon2)+np.sin(lon1)*np.sin(lon2))+np.sin(lat1)*np.sin(lat2)) - -def fix_wcs(wcs, axis=0): - """Returns a new WCS object which has had the reference pixel moved to the - middle of the possible pixel space.""" - res = wcs.deepcopy() - # Find the center ra manually: mean([crval - crpix*cdelt, crval + (-crpix+shape)*cdelt]) - # = crval + (-crpix+shape/2)*cdelt - # What pixel does this correspond to? - # crpix2 = crpix + (crval2-crval)/cdelt - # But that requires shape. Can we do without it? Yes, let's use the - # biggest possible shape. n = 360/cdelt - n = abs(360/wcs.wcs.cdelt[axis]) - delta_ra = wcs.wcs.cdelt[axis]*(n/2-wcs.wcs.crpix[axis]) - delta_pix = delta_ra/wcs.wcs.cdelt[axis] - res.wcs.crval[axis] += delta_ra - res.wcs.crpix[axis] += delta_pix - repr(res.wcs) # wcs not properly updated if I don't do this - return res - -def fix_cdelt(wcs): - """Return a new wcs with pc and cd replaced by cdelt""" - owcs = wcs.deepcopy() - if wcs.wcs.has_cd(): - del owcs.wcs.cd, owcs.wcs.pc - owcs.wcs.cdelt *= np.diag(wcs.wcs.cd) - elif wcs.wcs.has_pc(): - del owcs.wcs.cd, owcs.wcs.pc - owcs.wcs.cdelt *= np.diag(wcs.wcs.pc) - return owcs diff --git a/pyproject.toml b/pyproject.toml index 0b37a5c0..5af9fe21 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = ['meson-python', 'numpy', 'cython', 'versioneer[toml]', 'build'] [project] name = 'pixell' -version = "0.26.2" +version = "0.27.1" description = "A rectangular pixel map manipulation and harmonic analysis library derived from Sigurd Naess' enlib." 
readme = 'README.rst' requires-python = '>=3.9' diff --git a/tests/test_pixell.py b/tests/test_pixell.py index 05359067..8b207816 100644 --- a/tests/test_pixell.py +++ b/tests/test_pixell.py @@ -137,7 +137,7 @@ def get_geometries(yml_section): geos = {} for g in yml_section: if g['type']=='fullsky': - geos[g['name']] = enmap.fullsky_geometry(res=np.deg2rad(g['res_arcmin']/60.),proj=g['proj']) + geos[g['name']] = enmap.fullsky_geometry(res=np.deg2rad(g['res_arcmin']/60.),proj=g['proj'],variant="CC") elif g['type']=='pickle': geos[g['name']] = pickle.load(open(DATA_PREFIX+"%s"%g['filename'],'rb')) else: @@ -198,7 +198,7 @@ def get_extraction_test_results(yaml_file): lens_version = '071123' def get_offset_result(res=1.,dtype=np.float64,seed=1): - shape,wcs = enmap.fullsky_geometry(res=np.deg2rad(res)) + shape,wcs = enmap.fullsky_geometry(res=np.deg2rad(res), variant="CC") shape = (3,) + shape obs_pos = enmap.posmap(shape, wcs) np.random.seed(seed) @@ -207,7 +207,7 @@ def get_offset_result(res=1.,dtype=np.float64,seed=1): return obs_pos,grad,raw_pos def get_lens_result(res=1.,lmax=400,dtype=np.float64,seed=1): - shape,wcs = enmap.fullsky_geometry(res=np.deg2rad(res)) + shape,wcs = enmap.fullsky_geometry(res=np.deg2rad(res), variant="CC") shape = (3,) + shape # ells = np.arange(lmax) ps_cmb,ps_lens = powspec.read_camb_scalar(DATA_PREFIX+"test_scalCls.dat") @@ -485,7 +485,7 @@ def test_fullsky_geometry(self): print("Testing full sky geometry...") test_res_arcmin = 0.5 shape,wcs = enmap.fullsky_geometry(res=np.deg2rad(test_res_arcmin/60.),proj='car') - assert shape[0]==21601 and shape[1]==43200 + assert shape[0]==21600 and shape[1]==43200 assert abs(enmap.area(shape,wcs) - 4*np.pi) < 1e-6 def test_pixels(self): @@ -633,8 +633,8 @@ def test_project_nn(self): shape2,wcs2 = enmap.fullsky_geometry(res=np.deg2rad(6/60.),proj='car') shape3,wcs3 = enmap.fullsky_geometry(res=np.deg2rad(24/60.),proj='car') imap = enmap.ones(shape,wcs) - omap2 = 
enmap.project(imap,shape2,wcs2,order=0,mode='wrap') - omap3 = enmap.project(imap,shape3,wcs3,order=0,mode='wrap') + omap2 = enmap.project(imap,shape2,wcs2,order=0,border='wrap') + omap3 = enmap.project(imap,shape3,wcs3,order=0,border='wrap') assert np.all(np.isclose(omap2,1)) assert np.all(np.isclose(omap3,1)) @@ -1093,7 +1093,7 @@ def test_thumbnails(self): assert np.all(np.isclose(diff,0,atol=1e-3)) def test_tilemap(self): - shape, wcs = enmap.fullsky_geometry(30*utils.degree) + shape, wcs = enmap.fullsky_geometry(30*utils.degree, variant="CC") assert shape == (7,12) geo = tilemap.geometry((3,)+shape, wcs, tile_shape=(2,2)) assert len(geo.active) == 0