Ignore skims (#54)

* allow ignoring skims

* update min numba

* add tests and doc for ignore

* shared memory pre-init

* non-dask reload of data

* improved logging

* log bytes loaded so far

* fix bug

* docs and tests
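
The headline change adds an `ignore` option to `from_omx_3d`, which filters skim tables out by regular expression before they are loaded. A minimal usage sketch, with hypothetical file and table names (see also the matching notes beside `_should_ignore` further down):

import openmatrix
from sharrow.dataset import from_omx_3d

f = openmatrix.open_file("skims_a.omx")  # hypothetical skim file; keep open while the dask-backed data is in use
skims = from_omx_3d(
    [f],
    time_periods=["EA", "AM", "MD", "PM", "EV"],  # must be given explicitly
    ignore=["DIST", ".*_toll"],  # tables matching either regex are skipped
)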
jpn-- authored May 13, 2024
1 parent fb31a68 commit 05c3ce9
Showing 6 changed files with 362 additions and 109 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -4,7 +4,7 @@ numpy >= 1.19
pandas >= 1.2
pyarrow >= 3.0.0
xarray >= 0.20.0
numba >= 0.54
numba >= 0.57
numexpr
filelock
sphinx-autosummary-accessors
2 changes: 1 addition & 1 deletion envs/testing.yml
@@ -10,7 +10,7 @@ dependencies:
- xarray
- dask
- networkx
- numba>=0.54
- numba>=0.57
- numexpr
- sparse
- filelock
312 changes: 223 additions & 89 deletions sharrow/dataset.py
@@ -5,8 +5,9 @@
import hashlib
import logging
import re
from collections.abc import Hashable, Mapping, Sequence
from typing import Any
import time
from collections.abc import Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
@@ -17,8 +18,13 @@
from .accessors import register_dataset_method
from .aster import extract_all_name_tokens
from .categorical import _Categorical # noqa
from .shared_memory import si_units
from .table import Table

if TYPE_CHECKING:
import openmatrix


logger = logging.getLogger("sharrow")

well_known_names = {
@@ -283,7 +289,7 @@ def from_table(


def from_omx(
omx,
omx: openmatrix.File,
index_names=("otaz", "dtaz"),
indexes="one-based",
renames=None,
@@ -385,14 +391,23 @@ def from_omx(
return xr.Dataset.from_dict(d)


def _should_ignore(ignore, x):
if ignore is not None:
for i in ignore:
if re.match(i, x):
return True
return False
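# Note: patterns are tested with re.match, so they are anchored at the start
# of the table name but need not cover it in full. With hypothetical names:
#   _should_ignore(["DIST"], "DIST_AM")          -> True   (prefix match)
#   _should_ignore(["DIST"], "AM_DIST")          -> False  (no match at start)
#   _should_ignore([".*_toll"], "SOV_toll__AM")  -> True   (leading .* for substrings)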


def from_omx_3d(
omx,
omx: openmatrix.File | str | Iterable[openmatrix.File | str],
index_names=("otaz", "dtaz", "time_period"),
indexes=None,
*,
time_periods=None,
time_period_sep="__",
max_float_precision=32,
ignore=None,
):
"""
Create a Dataset from an OMX file with an implicit third dimension.
@@ -427,6 +442,12 @@ def from_omx_3d(
precision, generally to save memory if they were stored as double
precision but that level of detail is unneeded in the present
application.
ignore : list-like, optional
A list of regular expressions that will be used to filter out
variables from the dataset. If any of the regular expressions
match the name of a variable, that variable will not be included
in the loaded dataset. This is useful for excluding variables that
are not needed in the current application.
Returns
-------
@@ -435,103 +456,216 @@
if not isinstance(omx, (list, tuple)):
omx = [omx]

# handle both larch.OMX and openmatrix.open_file versions
if "larch" in type(omx[0]).__module__:
omx_shape = omx[0].shape
omx_lookup = omx[0].lookup
else:
omx_shape = omx[0].shape()
omx_lookup = omx[0].root["lookup"]
omx_data = []
omx_data_map = {}
for n, i in enumerate(omx):
if "larch" in type(i).__module__:
omx_data.append(i.data)
for k in i.data._v_children:
omx_data_map[k] = n
else:
omx_data.append(i.root["data"])
for k in i.root["data"]._v_children:
omx_data_map[k] = n

import dask.array
use_file_handles = []
opened_file_handles = []
for filename in omx:
if isinstance(filename, str):
import openmatrix

data_names = list(omx_data_map.keys())
n1, n2 = omx_shape
if indexes is None:
# default reads mapping if only one lookup is included, otherwise one-based
if len(omx_lookup._v_children) == 1:
ranger = None
indexes = list(omx_lookup._v_children)[0]
h = openmatrix.open_file(filename)
opened_file_handles.append(h)
use_file_handles.append(h)
else:
use_file_handles.append(filename)
omx = use_file_handles

try:
# handle both larch.OMX and openmatrix.open_file versions
if "larch" in type(omx[0]).__module__:
omx_shape = omx[0].shape
omx_lookup = omx[0].lookup
else:
omx_shape = omx[0].shape()
omx_lookup = omx[0].root["lookup"]
omx_data = []
omx_data_map = {}
for n, i in enumerate(omx):
if "larch" in type(i).__module__:
omx_data.append(i.data)
for k in i.data._v_children:
omx_data_map[k] = n
else:
omx_data.append(i.root["data"])
for k in i.root["data"]._v_children:
omx_data_map[k] = n

import dask.array

data_names = list(omx_data_map.keys())
if ignore is not None:
if isinstance(ignore, str):
ignore = [ignore]
data_names = [i for i in data_names if not _should_ignore(ignore, i)]
n1, n2 = omx_shape
if indexes is None:
# default reads mapping if only one lookup is included, otherwise one-based
if len(omx_lookup._v_children) == 1:
ranger = None
indexes = list(omx_lookup._v_children)[0]
else:
ranger = one_based
elif indexes == "one-based":
ranger = one_based
elif indexes == "one-based":
ranger = one_based
elif indexes == "zero-based":
ranger = zero_based
elif indexes in set(omx_lookup._v_children):
ranger = None
else:
raise NotImplementedError(
"only one-based, zero-based, and named indexes are implemented"
)
if ranger is not None:
r1 = ranger(n1)
r2 = ranger(n2)
else:
r1 = r2 = pd.Index(omx_lookup[indexes])
elif indexes == "zero-based":
ranger = zero_based
elif indexes in set(omx_lookup._v_children):
ranger = None
else:
raise NotImplementedError(
"only one-based, zero-based, and named indexes are implemented"
)
if ranger is not None:
r1 = ranger(n1)
r2 = ranger(n2)
else:
r1 = r2 = pd.Index(omx_lookup[indexes])

if time_periods is None:
raise ValueError("must give time periods explicitly")
if time_periods is None:
raise ValueError("must give time periods explicitly")

time_periods_map = {t: n for n, t in enumerate(time_periods)}
time_periods_map = {t: n for n, t in enumerate(time_periods)}

pending_3d = {}
content = {}
pending_3d = {}
content = {}

for k in data_names:
if time_period_sep in k:
base_k, time_k = k.split(time_period_sep, 1)
if base_k not in pending_3d:
pending_3d[base_k] = [None] * len(time_periods)
pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(
omx_data[omx_data_map[k]][k]
)
else:
content[k] = xr.DataArray(
dask.array.from_array(omx_data[omx_data_map[k]][k]),
dims=index_names[:2],
for k in data_names:
if time_period_sep in k:
base_k, time_k = k.split(time_period_sep, 1)
if base_k not in pending_3d:
pending_3d[base_k] = [None] * len(time_periods)
pending_3d[base_k][time_periods_map[time_k]] = dask.array.from_array(
omx_data[omx_data_map[k]][k]
)
else:
content[k] = xr.DataArray(
dask.array.from_array(omx_data[omx_data_map[k]][k]),
dims=index_names[:2],
coords={
index_names[0]: r1,
index_names[1]: r2,
},
)
for base_k, darrs in pending_3d.items():
# find a prototype array
prototype = None
for i in darrs:
prototype = i
if prototype is not None:
break
if prototype is None:
raise ValueError("no prototype")
darrs_ = [
(i if i is not None else dask.array.zeros_like(prototype))
for i in darrs
]
content[base_k] = xr.DataArray(
dask.array.stack(darrs_, axis=-1),
dims=index_names,
coords={
index_names[0]: r1,
index_names[1]: r2,
index_names[2]: time_periods,
},
)
for base_k, darrs in pending_3d.items():
# find a prototype array
prototype = None
for i in darrs:
prototype = i
if prototype is not None:
break
if prototype is None:
raise ValueError("no prototype")
darrs_ = [
(i if i is not None else dask.array.zeros_like(prototype)) for i in darrs
]
content[base_k] = xr.DataArray(
dask.array.stack(darrs_, axis=-1),
dims=index_names,
coords={
index_names[0]: r1,
index_names[1]: r2,
index_names[2]: time_periods,
},
)
for i in content:
if np.issubdtype(content[i].dtype, np.floating):
if content[i].dtype.itemsize > max_float_precision / 8:
content[i] = content[i].astype(f"float{max_float_precision}")
return xr.Dataset(content)
for i in content:
if np.issubdtype(content[i].dtype, np.floating):
if content[i].dtype.itemsize > max_float_precision / 8:
content[i] = content[i].astype(f"float{max_float_precision}")
return xr.Dataset(content)
finally:
for h in opened_file_handles:
h.close()


def reload_from_omx_3d(
dataset: xr.Dataset,
omx: Iterable[str],
*,
time_period_sep="__",
ignore=None,
) -> None:
"""
Reload the content of a dataset from OMX files.
This loads data from the OMX files into the dataset, replacing the
values already stored there. The dataset must have been created
by `from_omx_3d` or a similar function. Note that `from_omx_3d`
creates a dataset backed by `dask.array` objects; this function loads
the data without going through dask, which may have poor performance
on some platforms.
Parameters
----------
dataset : xr.Dataset
The dataset to reload into.
omx : Iterable[str]
The list of OMX file names to load from.
time_period_sep : str, default "__"
The separator used to identify time periods in the dataset.
ignore : list-like, optional
A list of regular expressions that will be used to filter out
variables from the dataset. If any of the regular expressions
match the name of a variable, that variable will not be included
in the load process. This is useful for excluding variables that
are not found in the target dataset.
"""
if isinstance(ignore, str):
ignore = [ignore]

use_file_handles = []
opened_file_handles = []
for filename in omx:
if isinstance(filename, str):
import openmatrix

h = openmatrix.open_file(filename)
opened_file_handles.append(h)
use_file_handles.append(h)
else:
use_file_handles.append(filename)
omx = use_file_handles

bytes_loaded = 0

try:
t0 = time.time()
for filename, f in zip(omx, use_file_handles):
if isinstance(filename, str):
logger.info(f"loading into dataset from {filename}")
for data_name in f.root.data._v_children:
if _should_ignore(ignore, data_name):
logger.info(f"ignoring {data_name}")
continue
t1 = time.time()
filters = f.root.data[data_name].filters
filter_note = f"{filters.complib}/{filters.complevel}"

if time_period_sep in data_name:
data_name_x, data_name_t = data_name.split(time_period_sep, 1)
if len(dataset[data_name_x].dims) != 3:
raise ValueError(
f"dataset variable {data_name_x} has "
f"{len(dataset[data_name_x].dims)} dimensions, expected 3"
)
raw = dataset[data_name_x].sel(time_period=data_name_t).data
raw[:, :] = f.root.data[data_name][:, :]
else:
if len(dataset[data_name].dims) != 2:
raise ValueError(
f"dataset variable {data_name} has "
f"{len(dataset[data_name].dims)} dimensions, expected 2"
)
raw = dataset[data_name].data
raw[:, :] = f.root.data[data_name][:, :]
bytes_loaded += raw.nbytes
logger.info(
f"loaded {data_name} ({filter_note}) to dataset "
f"in {time.time() - t1:.2f}s, {si_units(bytes_loaded)}"
)
logger.info(f"loading to dataset complete in {time.time() - t0:.2f}s")
finally:
for h in opened_file_handles:
h.close()
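
A hedged usage sketch of the reload path above, separate from the module source; file and table names are hypothetical, and in a real workflow the target might instead be a dataset pre-initialized in shared memory, per the commit notes above:

import openmatrix
from sharrow.dataset import from_omx_3d, reload_from_omx_3d

f = openmatrix.open_file("skims_a.omx")  # hypothetical skim file
skims = from_omx_3d([f], time_periods=["AM", "PM"], ignore=["DIST"])
skims = skims.compute()  # materialize the dask-backed arrays as writable numpy arrays
f.close()

# Later, refresh the values in place directly from the file, bypassing dask;
# reusing the same ignore patterns keeps the reload consistent with the dataset.
reload_from_omx_3d(skims, ["skims_a.omx"], ignore=["DIST"])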


def from_amx(
(diffs for the remaining changed files not shown)
