Skip to content

Commit

Permalink
Better hardware fix (#39)
Browse files Browse the repository at this point in the history
fix for apple silicon, and deprecation warning on Dataset.dims
  • Loading branch information
jpn-- authored Jan 10, 2024
1 parent 5d07145 commit b6455af
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 28 deletions.
27 changes: 13 additions & 14 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,7 @@ def __initialize_1(
all_raw_names |= attribute_pairs.get(self.tree.root_node_name, set())
all_raw_names |= subscript_pairs.get(self.tree.root_node_name, set())

dimensions_ordered = presorted(self.tree.dims, self.dim_order, self.dim_exclude)
dimensions_ordered = presorted(self.tree.sizes, self.dim_order, self.dim_exclude)
index_slots = {i: n for n, i in enumerate(dimensions_ordered)}
self.arg_name_positions = index_slots
self.arg_names = dimensions_ordered
Expand All @@ -1074,7 +1074,7 @@ def __initialize_1(
self._used_aux_vars.append(aux_var)

subspace_names = set()
for (k, _) in self.tree.subspaces_iter():
for k, _ in self.tree.subspaces_iter():
subspace_names.add(k)
for k in self.tree.subspace_fallbacks:
subspace_names.add(k)
Expand All @@ -1083,7 +1083,7 @@ def __initialize_1(
)
self._optional_get_tokens = []
if optional_get_tokens:
for (_spacename, _varname) in optional_get_tokens:
for _spacename, _varname in optional_get_tokens:
found = False
if (
_spacename in self.tree.subspaces
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def _index_slots(self):
return {
i: n
for n, i in enumerate(
presorted(self.tree.dims, self.dim_order, self.dim_exclude)
presorted(self.tree.sizes, self.dim_order, self.dim_exclude)
)
}

Expand All @@ -1219,7 +1219,7 @@ def init_sub_funcs(
index_slots = {
i: n
for n, i in enumerate(
presorted(self.tree.dims, self.dim_order, self.dim_exclude)
presorted(self.tree.sizes, self.dim_order, self.dim_exclude)
)
}
self.arg_name_positions = index_slots
Expand Down Expand Up @@ -1665,7 +1665,6 @@ def __initialize_2(
with rewrite(
os.path.join(self.cache_dir, self.name, "__init__.py"), "wt"
) as f_code:

f_code.write(
textwrap.dedent(
f"""
Expand Down Expand Up @@ -1719,13 +1718,15 @@ def __initialize_2(
f_code.write("\n\n# machinery code\n\n")

if self.tree.relationships_are_digitized:
if with_root_node_name is None:
with_root_node_name = self.tree.root_node_name

if with_root_node_name is None:
with_root_node_name = self.tree.root_node_name

root_dims = list(
presorted(
self.tree._graph.nodes[with_root_node_name]["dataset"].dims,
self.tree._graph.nodes[with_root_node_name]["dataset"].sizes,
self.dim_order,
self.dim_exclude,
)
Expand Down Expand Up @@ -1803,7 +1804,6 @@ def __initialize_2(
raise ValueError(f"invalid n_root_dims {n_root_dims}")

else:

raise RuntimeError("digitization is now required")

f_code.write(blacken(textwrap.dedent(line_template)))
Expand Down Expand Up @@ -2048,11 +2048,11 @@ def _iload_raw(
kwargs["mask"] = mask

if self.with_root_node_name is None:
tree_root_dims = rg.root_dataset.dims
tree_root_dims = rg.root_dataset.sizes
else:
tree_root_dims = rg._graph.nodes[self.with_root_node_name][
"dataset"
].dims
].sizes
argshape = [
tree_root_dims[i]
for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude)
Expand Down Expand Up @@ -2266,12 +2266,12 @@ def _load(

if self.with_root_node_name is None:
use_dims = list(
presorted(source.root_dataset.dims, self.dim_order, self.dim_exclude)
presorted(source.root_dataset.sizes, self.dim_order, self.dim_exclude)
)
else:
use_dims = list(
presorted(
source._graph.nodes[self.with_root_node_name]["dataset"].dims,
source._graph.nodes[self.with_root_node_name]["dataset"].sizes,
self.dim_order,
self.dim_exclude,
)
Expand Down Expand Up @@ -2441,7 +2441,6 @@ def _load(
{k: result[:, n] for n, k in enumerate(self._raw_functions.keys())}
)
elif as_dataarray:

if result_squeeze:
result = squeeze(result, result_squeeze)
result_p = squeeze(result_p, result_squeeze)
Expand Down Expand Up @@ -2849,7 +2848,7 @@ def init_streamer(self, source=None, dtype=None):

selected_args = tuple(general_mapping[k] for k in named_args)
len_self_raw_functions = len(self._raw_functions)
tree_root_dims = source.root_dataset.dims
tree_root_dims = source.root_dataset.sizes
argshape = tuple(
tree_root_dims[i]
for i in presorted(tree_root_dims, self.dim_order, self.dim_exclude)
Expand Down
1 change: 0 additions & 1 deletion sharrow/nested_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def _utility_to_probability(
logprob, # float output shape=[nodes]
probability, # float output shape=[nodes]
):

for up in range(n_alts, utility.size):
up_nest = up - n_alts
n_children_for_parent = len_slots[up_nest]
Expand Down
7 changes: 3 additions & 4 deletions sharrow/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,6 @@ def add_dataset(self, name, dataset, relationships=(), as_root=False):
self.digitize_relationships(inplace=True)

def add_items(self, items):

from collections.abc import Mapping, Sequence

if isinstance(items, Sequence):
Expand Down Expand Up @@ -707,7 +706,6 @@ def _getitem(
just_node_name=False,
dim_names_from_top=False,
):

if isinstance(item, (list, tuple)):
from .dataset import Dataset

Expand Down Expand Up @@ -790,7 +788,7 @@ def _getitem(
# path_indexing = self._graph.edges[path[-1]].get('indexing')
t1 = None
# intermediate nodes on path
for (e, e_next) in zip(path[:-1], path[1:]):
for e, e_next in zip(path[:-1], path[1:]):
r = self._get_relationship(e)
r_next = self._get_relationship(e_next)
if t1 is None:
Expand Down Expand Up @@ -971,6 +969,8 @@ def dims(self):
dims[name] = length
return xr.core.utils.Frozen(dims)

sizes = dims # alternate name

def dims_detail(self):
"""
Report on the names and sizes of dimensions in all Dataset nodes.
Expand Down Expand Up @@ -1395,7 +1395,6 @@ def relationships_are_digitized(self):
def _arg_tokenizer(
self, spacename, spacearray, spacearrayname, exclude_dims=None, blends=None
):

if blends is None:
blends = {}

Expand Down
2 changes: 0 additions & 2 deletions sharrow/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@


def si_units(x, kind="B", digits=3, shift=1000):

# nano micro milli kilo mega giga tera peta exa zeta yotta
tiers = ["n", "µ", "m", "", "K", "M", "G", "T", "P", "E", "Z", "Y"]

Expand Down Expand Up @@ -219,7 +218,6 @@ def delete_shared_memory_files(key):

@xr.register_dataset_accessor("shm")
class SharedMemDatasetAccessor:

_parent_class = xr.Dataset

def __init__(self, xarray_obj):
Expand Down
23 changes: 18 additions & 5 deletions sharrow/sparse.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math

import numba as nb
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -197,15 +199,26 @@ def blenders(self):
return b


@nb.generated_jit(nopython=True)
# fastmath must be false to ensure NaNs are detected here.
# wrapping this as such allows fastmath to be turned on in outer functions
# but not lose the ability to check for NaNs. Older versions of this function
# checked whether the float cast to an integer was -9223372036854775808, but
# that turns out to be not compatible with all hardware (i.e. Apple Silicon).
def isnan_fast_safe(x):
    """Return True if *x* should be treated as a missing ("NaN") value.

    Floats are tested with :func:`math.isnan`.  Strings are considered
    missing only when they equal ``"\\u0015"`` (the NAK control character
    — presumably this package's missing-string sentinel; confirm against
    callers).  Values of any other type are never considered missing.
    """
    if isinstance(x, str):
        # string sentinel check
        return x == "\u0015"
    if isinstance(x, float):
        return math.isnan(x)
    # ints, None, arrays, etc. are never "missing" here
    return False


@nb.extending.overload(isnan_fast_safe, jit_options={"fastmath": False})
def ol_isnan_fast_safe(x):
if isinstance(x, nb.types.Float):

def func(x):
if int(x) == -9223372036854775808:
return True
else:
return False
return math.isnan(x)

return func
elif isinstance(x, (nb.types.UnicodeType, nb.types.UnicodeCharSeq)):
Expand Down
3 changes: 2 additions & 1 deletion sharrow/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import pytest
import xarray as xr
import pandas as pd

from sharrow.dataset import construct


Expand Down
1 change: 0 additions & 1 deletion sharrow/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ def omx_to_zarr(
time_periods=None,
time_period_sep="__",
):

bucket = {}

r1 = r2 = None
Expand Down

0 comments on commit b6455af

Please sign in to comment.