From 05c3ce9157f77217c44bd6f10540ea20e153a218 Mon Sep 17 00:00:00 2001 From: Jeffrey Newman Date: Mon, 13 May 2024 16:02:19 -0500 Subject: [PATCH] Ignore skims (#54) * allow ignoring skims * update min numba * add tests and doc for ignore * shared memory pre-init * non-dask reload of data * improved logging * log bytes loaded so far * fix bug * docs and tests --- docs/requirements.txt | 2 +- envs/testing.yml | 2 +- sharrow/dataset.py | 312 +++++++++++++++++++++++---------- sharrow/example_data.py | 5 + sharrow/shared_memory.py | 82 +++++++-- sharrow/tests/test_datasets.py | 68 +++++++ 6 files changed, 362 insertions(+), 109 deletions(-) diff --git a/docs/requirements.txt b/docs/requirements.txt index 821305e..09e8b90 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ numpy >= 1.19 pandas >= 1.2 pyarrow >= 3.0.0 xarray >= 0.20.0 -numba >= 0.54 +numba >= 0.57 numexpr filelock sphinx-autosummary-accessors diff --git a/envs/testing.yml b/envs/testing.yml index 581eee1..04db37d 100644 --- a/envs/testing.yml +++ b/envs/testing.yml @@ -10,7 +10,7 @@ dependencies: - xarray - dask - networkx - - numba>=0.54 + - numba>=0.57 - numexpr - sparse - filelock diff --git a/sharrow/dataset.py b/sharrow/dataset.py index 85b7de2..fc2bcc4 100755 --- a/sharrow/dataset.py +++ b/sharrow/dataset.py @@ -5,8 +5,9 @@ import hashlib import logging import re -from collections.abc import Hashable, Mapping, Sequence -from typing import Any +import time +from collections.abc import Hashable, Iterable, Mapping, Sequence +from typing import TYPE_CHECKING, Any import numpy as np import pandas as pd @@ -17,8 +18,13 @@ from .accessors import register_dataset_method from .aster import extract_all_name_tokens from .categorical import _Categorical # noqa +from .shared_memory import si_units from .table import Table +if TYPE_CHECKING: + import openmatrix + + logger = logging.getLogger("sharrow") well_known_names = { @@ -283,7 +289,7 @@ def from_table( def from_omx( - omx, + omx: openmatrix.File, index_names=("otaz", "dtaz"), indexes="one-based", renames=None, @@ -385,14 +391,23 @@ def from_omx( return xr.Dataset.from_dict(d) +def _should_ignore(ignore, x): + if ignore is not None: + for i in ignore: + if re.match(i, x): + return True + return False + + def from_omx_3d( - omx, + omx: openmatrix.File | str | Iterable[openmatrix.File | str], index_names=("otaz", "dtaz", "time_period"), indexes=None, *, time_periods=None, time_period_sep="__", max_float_precision=32, + ignore=None, ): """ Create a Dataset from an OMX file with an implicit third dimension. @@ -427,6 +442,12 @@ def from_omx_3d( precision, generally to save memory if they were stored as double precision but that level of detail is unneeded in the present application. + ignore : list-like, optional + A list of regular expressions that will be used to filter out + variables from the dataset. If any of the regular expressions + match the name of a variable, that variable will not be included + in the loaded dataset. This is useful for excluding variables that + are not needed in the current application. Returns ------- @@ -435,103 +456,216 @@ def from_omx_3d( if not isinstance(omx, (list, tuple)): omx = [omx] - # handle both larch.OMX and openmatrix.open_file versions - if "larch" in type(omx[0]).__module__: - omx_shape = omx[0].shape - omx_lookup = omx[0].lookup - else: - omx_shape = omx[0].shape() - omx_lookup = omx[0].root["lookup"] - omx_data = [] - omx_data_map = {} - for n, i in enumerate(omx): - if "larch" in type(i).__module__: - omx_data.append(i.data) - for k in i.data._v_children: - omx_data_map[k] = n - else: - omx_data.append(i.root["data"]) - for k in i.root["data"]._v_children: - omx_data_map[k] = n - - import dask.array + use_file_handles = [] + opened_file_handles = [] + for filename in omx: + if isinstance(filename, str): + import openmatrix - data_names = list(omx_data_map.keys()) - n1, n2 = omx_shape - if indexes is None: - # default reads mapping if only one lookup is included, otherwise one-based - if len(omx_lookup._v_children) == 1: - ranger = None - indexes = list(omx_lookup._v_children)[0] + h = openmatrix.open_file(filename) + opened_file_handles.append(h) + use_file_handles.append(h) + else: + use_file_handles.append(filename) + omx = use_file_handles + + try: + # handle both larch.OMX and openmatrix.open_file versions + if "larch" in type(omx[0]).__module__: + omx_shape = omx[0].shape + omx_lookup = omx[0].lookup else: + omx_shape = omx[0].shape() + omx_lookup = omx[0].root["lookup"] + omx_data = [] + omx_data_map = {} + for n, i in enumerate(omx): + if "larch" in type(i).__module__: + omx_data.append(i.data) + for k in i.data._v_children: + omx_data_map[k] = n + else: + omx_data.append(i.root["data"]) + for k in i.root["data"]._v_children: + omx_data_map[k] = n + + import dask.array + + data_names = list(omx_data_map.keys()) + if ignore is not None: + if isinstance(ignore, str): + ignore = [ignore] + data_names = [i for i in data_names if not _should_ignore(ignore, i)] + n1, n2 = omx_shape + if indexes is None: + # default reads mapping if only one lookup is included, otherwise one-based + if len(omx_lookup._v_children) == 1: + ranger = None + indexes = list(omx_lookup._v_children)[0] + else: + ranger = one_based + elif indexes == "one-based": ranger = one_based - elif indexes == "one-based": - ranger = one_based - elif indexes == "zero-based": - ranger = zero_based - elif indexes in set(omx_lookup._v_children): - ranger = None - else: - raise NotImplementedError( - "only one-based, zero-based, and named indexes are implemented" - ) - if ranger is not None: - r1 = ranger(n1) - r2 = ranger(n2) - else: - r1 = r2 = pd.Index(omx_lookup[indexes]) + elif indexes == "zero-based": + ranger = zero_based + elif indexes in set(omx_lookup._v_children): + ranger = None + else: + raise NotImplementedError( + "only one-based, zero-based, and named indexes are implemented" + ) + if ranger is not None: + r1 = ranger(n1) + r2 = ranger(n2) + else: + r1 = r2 = pd.Index(omx_lookup[indexes]) - if time_periods is None: - raise ValueError("must give time periods explicitly") + if time_periods is None: + raise ValueError("must give time periods explicitly") - time_periods_map = {t: n for n, t in enumerate(time_periods)} + time_periods_map = {t: n for n, t in enumerate(time_periods)} - pending_3d = {} - content = {} + pending_3d = {} + content = {} - for k in data_names: - if time_period_sep in k: - base_k, time_k = k.split(time_period_sep, 1) - if base_k not in pending_3d: - pending_3d[base_k] = [None] * len(time_periods) - pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array( - omx_data[omx_data_map[k]][k] - ) - else: - content[k] = xr.DataArray( - dask.array.from_array(omx_data[omx_data_map[k]][k]), - dims=index_names[:2], + for k in data_names: + if time_period_sep in k: + base_k, time_k = k.split(time_period_sep, 1) + if base_k not in pending_3d: + pending_3d[base_k] = [None] * len(time_periods) + pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array( + omx_data[omx_data_map[k]][k] + ) + else: + content[k] = xr.DataArray( + dask.array.from_array(omx_data[omx_data_map[k]][k]), + dims=index_names[:2], + coords={ + index_names[0]: r1, + index_names[1]: r2, + }, + ) + for base_k, darrs in pending_3d.items(): + # find a prototype array + prototype = None + for i in darrs: + prototype = i + if prototype is not None: + break + if prototype is None: + raise ValueError("no prototype") + darrs_ = [ + (i if i is not None else dask.array.zeros_like(prototype)) + for i in darrs + ] + content[base_k] = xr.DataArray( + dask.array.stack(darrs_, axis=-1), + dims=index_names, coords={ index_names[0]: r1, index_names[1]: r2, + index_names[2]: time_periods, }, ) - for base_k, darrs in pending_3d.items(): - # find a prototype array - prototype = None - for i in darrs: - prototype = i - if prototype is not None: - break - if prototype is None: - raise ValueError("no prototype") - darrs_ = [ - (i if i is not None else dask.array.zeros_like(prototype)) for i in darrs - ] - content[base_k] = xr.DataArray( - dask.array.stack(darrs_, axis=-1), - dims=index_names, - coords={ - index_names[0]: r1, - index_names[1]: r2, - index_names[2]: time_periods, - }, - ) - for i in content: - if np.issubdtype(content[i].dtype, np.floating): - if content[i].dtype.itemsize > max_float_precision / 8: - content[i] = content[i].astype(f"float{max_float_precision}") - return xr.Dataset(content) + for i in content: + if np.issubdtype(content[i].dtype, np.floating): + if content[i].dtype.itemsize > max_float_precision / 8: + content[i] = content[i].astype(f"float{max_float_precision}") + return xr.Dataset(content) + finally: + for h in opened_file_handles: + h.close() + + +def reload_from_omx_3d( + dataset: xr.Dataset, + omx: Iterable[str], + *, + time_period_sep="__", + ignore=None, +) -> None: + """ + Reload the content of a dataset from OMX files. + + This loads the data from the OMX files into the dataset, replacing + the existing data in the dataset. The dataset must have been created + by `from_omx_3d` or a similar function. Note that `from_omx_3d` will + create a dataset backed by `dask.array` objects; this function allows for + loading the data without going through dask, which may have poor performance + on some platforms. + + Parameters + ---------- + dataset : xr.Dataset + The dataset to reload into. + omx : Iterable[str] + The list of OMX file names to load from. + time_period_sep : str, default "__" + The separator used to identify time periods in the dataset. + ignore : list-like, optional + A list of regular expressions that will be used to filter out + variables from the dataset. If any of the regular expressions + match the name of a variable, that variable will not be included + in the load process. This is useful for excluding variables that + are not found in the target dataset. + """ + if isinstance(ignore, str): + ignore = [ignore] + + use_file_handles = [] + opened_file_handles = [] + for filename in omx: + if isinstance(filename, str): + import openmatrix + + h = openmatrix.open_file(filename) + opened_file_handles.append(h) + use_file_handles.append(h) + else: + use_file_handles.append(filename) + omx = use_file_handles + + bytes_loaded = 0 + + try: + t0 = time.time() + for filename, f in zip(omx, use_file_handles): + if isinstance(filename, str): + logger.info(f"loading into dataset from {filename}") + for data_name in f.root.data._v_children: + if _should_ignore(ignore, data_name): + logger.info(f"ignoring {data_name}") + continue + t1 = time.time() + filters = f.root.data[data_name].filters + filter_note = f"{filters.complib}/{filters.complevel}" + + if time_period_sep in data_name: + data_name_x, data_name_t = data_name.split(time_period_sep, 1) + if len(dataset[data_name_x].dims) != 3: + raise ValueError( + f"dataset variable {data_name_x} has " + f"{len(dataset[data_name_x].dims)} dimensions, expected 3" + ) + raw = dataset[data_name_x].sel(time_period=data_name_t).data + raw[:, :] = f.root.data[data_name][:, :] + else: + if len(dataset[data_name].dims) != 2: + raise ValueError( + f"dataset variable {data_name} has " + f"{len(dataset[data_name].dims)} dimensions, expected 2" + ) + raw = dataset[data_name].data + raw[:, :] = f.root.data[data_name][:, :] + bytes_loaded += raw.nbytes + logger.info( + f"loaded {data_name} ({filter_note}) to dataset " + f"in {time.time() - t1:.2f}s, {si_units(bytes_loaded)}" + ) + logger.info(f"loading to dataset complete in {time.time() - t0:.2f}s") + finally: + for h in opened_file_handles: + h.close() def from_amx( diff --git a/sharrow/example_data.py b/sharrow/example_data.py index 308d9a9..69572c6 100644 --- a/sharrow/example_data.py +++ b/sharrow/example_data.py @@ -4,6 +4,11 @@ import pandas as pd +def get_skims_filename() -> str: + """Return the path to the example skims file.""" + return os.path.join(os.path.dirname(__file__), "example_data", "skims.omx") + + def get_skims(): import openmatrix diff --git a/sharrow/shared_memory.py b/sharrow/shared_memory.py index 0098c06..4ec3b6d 100644 --- a/sharrow/shared_memory.py +++ b/sharrow/shared_memory.py @@ -31,7 +31,11 @@ def si_units(x, kind="B", digits=3, shift=1000): tiers = ["n", "ยต", "m", "", "K", "M", "G", "T", "P", "E", "Z", "Y"] tier = 3 - sign = "-" if x < 0 else "" + try: + sign = "-" if x < 0 else "" + except TypeError: + # x is not a number, just return it + return x x = abs(x) if x > 0: while x > shift and tier < len(tiers): @@ -249,7 +253,13 @@ def delete_shared_memory_files(key): delete_shared_memory_files(key) def to_shared_memory( - self, key=None, mode="r+", _dupe=True, dask_scheduler="threads" + self, + key=None, + mode="r+", + _dupe=True, + dask_scheduler="threads", + pre_init=False, + load=True, ): """ Load this Dataset into shared memory. @@ -270,8 +280,17 @@ def to_shared_memory( See numpy.memmap() for details. dask_scheduler : str, default 'threads' The scheduler to use when loading dask arrays into shared memory. - Typically "threads" for multi-threaded reads or "synchronous" + Typically, this is "threads" for multithreaded reads or "synchronous" for single-threaded reads. See dask.compute() for details. + pre_init : bool, default False + If True, the shared memory buffer will be pre-initialized with zeros. + This is generally not necessary, but can be useful for debugging. + load : bool, default True + If True, load the data into shared memory immediately, using dask. + If False, defer loading until later. Deferred tasks are stored in + the `shm.tasks` attribute of the resulting Dataset object, but do not + necessarily need to be run if data can be loaded using alternative + methods (e.g. `sharrow.dataset.reload_from_omx_3d`). Returns ------- @@ -294,7 +313,7 @@ def to_shared_memory( def emit(k, a, is_coord): nonlocal names, wrappers, sizes, position if sparse is not None and isinstance(a.data, sparse.GCXS): - logger.info(f"preparing sparse array {a.name}") + logger.debug(f"preparing sparse array {a.name}") wrappers.append( { "sparse": True, @@ -316,7 +335,7 @@ def emit(k, a, is_coord): ) a_nbytes = a.data.nbytes else: - logger.info(f"preparing dense array {a.name}") + logger.debug(f"preparing dense array {a.name}") wrappers.append( { "dims": a.dims, @@ -345,12 +364,17 @@ def emit(k, a, is_coord): mem = create_shared_memory_array(key, size=position) - logger.info("declaring shared memory buffer") + logger.debug("declaring shared memory buffer") if key.startswith("memmap:"): buffer = memoryview(mem) else: buffer = mem.buf + if pre_init: + logger.debug("pre-initializing shared memory buffer") + # gross init with all zeros + buffer[:] = b"\0" * len(buffer) + tasks = [] task_names = [] for w in wrappers: @@ -386,7 +410,7 @@ def emit(k, a, is_coord): mem_arr_i[:] = ad.indices[:] mem_arr_p[:] = ad.indptr[:] else: - logger.info(f"preparing load task: {_name} ({si_units(_size)})") + logger.debug(f"preparing load task: {_name} ({si_units(_size)})") mem_arr = np.ndarray( shape=a.shape, dtype=a.dtype, buffer=buffer[_pos : _pos + _size] ) @@ -395,16 +419,10 @@ def emit(k, a, is_coord): task_names.append(_name) else: mem_arr[:] = a[:] - if tasks: - t = time.time() - logger.info(f"running {len(tasks)} dask data load tasks") - if dask_scheduler == "synchronous": - for task, task_name in zip(tasks, task_names): - logger.info(f"running load task: {task_name}") - dask.compute(task, scheduler=dask_scheduler) - else: - dask.compute(tasks, scheduler=dask_scheduler) - logger.info(f"completed dask data load in {time.time()-t:.3f} seconds") + if tasks and load: + self.tasks = tasks + self.task_names = task_names + self.run_tasks(dask_scheduler=dask_scheduler) if key.startswith("memmap:"): mem.flush() @@ -413,7 +431,35 @@ def emit(k, a, is_coord): create_shared_list( [pickle.dumps(self._obj.attrs)] + [pickle.dumps(i) for i in wrappers], key ) - return type(self).from_shared_memory(key, own_data=mem, mode=mode) + result = type(self).from_shared_memory(key, own_data=mem, mode=mode) + if tasks and not load: + # attach incompleted tasks to the result + result.shm.tasks = tasks + result.shm.task_names = task_names + result.shm._buffer = buffer + result.shm._position = position + return result + + def run_tasks(self, dask_scheduler="threads"): + """Run any deferred dask tasks.""" + if not hasattr(self, "tasks"): + return + else: + tasks = self.tasks + if not hasattr(self, "task_names"): + task_names = ["untitled" for _ in tasks] + else: + task_names = self.task_names + t = time.time() + logger.info(f"running {len(tasks)} dask data load tasks") + if dask_scheduler == "synchronous": + for task, task_name in zip(tasks, task_names): + logger.info(f"running load task: {task_name}") + dask.compute(task, scheduler=dask_scheduler) + else: + dask.compute(tasks, scheduler=dask_scheduler) + logger.info(f"completed tasks in {time.time() - t:.3f} seconds") + del self.tasks @property def shared_memory_key(self): diff --git a/sharrow/tests/test_datasets.py b/sharrow/tests/test_datasets.py index d2da374..72d2104 100644 --- a/sharrow/tests/test_datasets.py +++ b/sharrow/tests/test_datasets.py @@ -1,9 +1,11 @@ +import secrets import tempfile from pathlib import Path import numpy as np import openmatrix import pandas as pd +import xarray as xr from pytest import approx import sharrow as sh @@ -65,3 +67,69 @@ def income_cat(i): recovered_df = hd.single_dim.to_pandas() pd.testing.assert_frame_equal(hhs, recovered_df) + + +def test_load_with_ignore(): + filename = sh.example_data.get_skims_filename() + with openmatrix.open_file(filename) as f: + skims = sh.dataset.from_omx_3d( + f, + index_names=("otaz", "dtaz", "time_period"), + indexes=None, + time_periods=["EA", "AM", "MD", "PM", "EV"], + time_period_sep="__", + max_float_precision=32, + ) + assert "DRV_COM_WLK_FAR" in skims.variables + + with openmatrix.open_file(filename) as f: + skims1 = sh.dataset.from_omx_3d( + f, + index_names=("otaz", "dtaz", "time_period"), + indexes=None, + time_periods=["EA", "AM", "MD", "PM", "EV"], + time_period_sep="__", + max_float_precision=32, + ignore=["DRV_COM_WLK_.*"], + ) + assert "DRV_COM_WLK_FAR" not in skims1.variables + + with openmatrix.open_file(filename) as f: + skims2 = sh.dataset.from_omx_3d( + f, + index_names=("otaz", "dtaz", "time_period"), + indexes=None, + time_periods=["EA", "AM", "MD", "PM", "EV"], + time_period_sep="__", + max_float_precision=32, + ignore="DRV_COM_WLK_.*", + ) + print(skims2) + assert "DISTBIKE" in skims2.variables + assert "DRV_COM_WLK_FAR" not in skims2.variables + + +def test_deferred_load_to_shared_memory(): + """ + Test of deferred loading of data into shared memory. + + Checks that skim data is loaded correctly into shared memory + when using the `to_shared_memory` method with `load=False`, followed by + a call to `reload_from_omx_3d`. + """ + from sharrow.example_data import get_skims_filename + + skims_filename = get_skims_filename() + with openmatrix.open_file(skims_filename) as f: + d0 = sh.dataset.from_omx_3d( + f, + index_names=("otaz", "dtaz", "time_period"), + time_periods=["EA", "AM", "MD", "PM", "EV"], + max_float_precision=32, + ) + token = "skims" + secrets.token_hex(5) + d1 = d0.shm.to_shared_memory(token, mode="r", load=False) + sh.dataset.reload_from_omx_3d(d1, [skims_filename]) + xr.testing.assert_equal(d0, d1) + d2 = xr.Dataset.shm.from_shared_memory(token) + xr.testing.assert_equal(d0, d2)