Skip to content

Commit

Permalink
Downversion (#6)
Browse files Browse the repository at this point in the history
* fix link in docs

* make compat with py3.7

* add to testing matrix

* conditional dependency

* trigger on develop

* install deps

* ignore 3.8 for now
  • Loading branch information
jpn-- authored Mar 6, 2022
1 parent e4251b1 commit f0c0b92
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 41 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ name: sharrow testing

on:
push:
branches: [ main ]
branches: [ main, develop ]
tags:
- 'v[0-9]+.[0-9]+**'
pull_request:
branches: [ main ]
branches: [ main, develop ]
tags:
- 'v[0-9]+.[0-9]+**'
workflow_dispatch:
Expand All @@ -18,7 +18,7 @@ jobs:
strategy:
matrix:
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
python-version: ["3.9"]
python-version: ["3.7", "3.9"]
defaults:
run:
shell: bash -l {0}
Expand All @@ -41,7 +41,7 @@ jobs:
auto-update-conda: false
- name: Install sharrow
run: |
python -m pip install --no-deps -e .
python -m pip install -e .
- name: Conda checkup
run: |
conda info -a
Expand Down
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Convenience
Dataset
--------------------------------------------------------------------------------
Sharrow uses the :py:class:`xarray.Dataset` class extensively. Refer to the
`xarray documentation <https://docs.xarray.dev/en/stable/>` for standard usage.
`xarray documentation <https://docs.xarray.dev/en/stable/>`_ for standard usage.
The attributes and methods documented here are added to :py:class:`xarray.Dataset`
when you import sharrow.

Expand Down
11 changes: 11 additions & 0 deletions docs/walkthrough/one-dim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@
"sh.__version__"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa0564dd",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.version_info < (3,8)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
7 changes: 4 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ url = https://github.com/ActivitySim/sharrow
packages = find:
zip_safe = False
include_package_data = True
python_requires = >=3.9
python_requires = >=3.7
install_requires =
numpy >= 1.19
pandas >= 1.2
pyarrow >= 3.0.0
xarray >= 0.20.0
numba >= 0.53
numba >= 0.51.2
numexpr
filelock
dask
networkx
astunparse;python_version<'3.9'

[flake8]
exclude =
Expand All @@ -28,4 +29,4 @@ exclude =
docs/_build,
sharrow/__init__.py
max-line-length = 160
extend-ignore = E203
extend-ignore = E203, E731
77 changes: 56 additions & 21 deletions sharrow/aster.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import ast
import io
import logging
import sys
import tokenize

try:
from ast import unparse
except ImportError:
from astunparse import unparse as _unparse

unparse = lambda *args: _unparse(*args).strip("\n")

logger = logging.getLogger("sharrow.aster")

if sys.version_info >= (3, 8):
ast_Constant_Type = ast.Constant
ast_String_value = lambda x: x
else:
ast_Constant_Type = (ast.Index, ast.Constant)
ast_String_value = lambda x: x.s if isinstance(x, ast.Str) else x


def _isNone(c):
if c is None:
return True
if isinstance(c, ast.Constant) and c.value is None:
if isinstance(c, ast_Constant_Type) and c.value is None:
return True
return False

Expand Down Expand Up @@ -319,7 +334,7 @@ def log_event(self, tag, node1=None, node2=None):
)
elif node2 is None:
try:
unparsed = ast.unparse(node1)
unparsed = unparse(node1)
except: # noqa: E722
unparsed = f"{type(node1)} not unparseable"
logger.log(
Expand All @@ -328,11 +343,11 @@ def log_event(self, tag, node1=None, node2=None):
)
else:
try:
unparsed1 = ast.unparse(node1)
unparsed1 = unparse(node1)
except: # noqa: E722
unparsed1 = f"{type(node1).__name__} not unparseable"
try:
unparsed2 = ast.unparse(node2)
unparsed2 = unparse(node2)
except: # noqa: E722
unparsed2 = f"{type(node2).__name__} not unparseable"
logger.log(
Expand Down Expand Up @@ -428,13 +443,15 @@ def _replacement(
keywords=[],
)
else:
if (scale := digital_encoding.get("scale", 1)) != 1:
scale = digital_encoding.get("scale", 1)
offset = digital_encoding.get("offset", 0)
if scale != 1:
result = ast.BinOp(
left=result,
op=ast.Mult(),
right=ast.Num(scale),
)
if offset := digital_encoding.get("offset", 0):
if offset:
result = ast.BinOp(
left=result,
op=ast.Add(),
Expand All @@ -447,23 +464,31 @@ def visit_Subscript(self, node):
if isinstance(node.value, ast.Name):
if (
node.value.id == self.spacename
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
and isinstance(node.slice, ast_Constant_Type)
and isinstance(ast_String_value(node.slice.value), str)
):
self.log_event(f"visit_Subscript(Constant {node.slice.value})")
return self._replacement(node.slice.value, node.ctx, node)
return self._replacement(
ast_String_value(node.slice.value), node.ctx, node
)
if (
node.value.id == self.rawalias
and isinstance(node.slice, ast.Constant)
and isinstance(node.slice.value, str)
and node.slice.value in self.spacevars
and isinstance(node.slice, ast_Constant_Type)
and isinstance(ast_String_value(node.slice.value), str)
and ast_String_value(node.slice.value) in self.spacevars
):
result = ast.Subscript(
value=ast.Name(id=self.rawname, ctx=ast.Load()),
slice=ast.Constant(self.spacevars[node.slice.value]),
slice=ast.Constant(
self.spacevars[ast_String_value(node.slice.value)]
),
ctx=node.ctx,
)
self.log_event(f"visit_Subscript(Raw {node.slice.value})", node, result)
self.log_event(
f"visit_Subscript(Raw {ast_String_value(node.slice.value)})",
node,
result,
)
return result
self.log_event("visit_Subscript(no change)", node)
return node
Expand Down Expand Up @@ -550,9 +575,14 @@ def visit_Call(self, node):
if isinstance(node.func, ast.Attribute) and node.func.attr == "reverse":
if isinstance(node.func.value, ast.Name):
if node.func.value.id == self.spacename:
if len(node.args) == 1 and isinstance(node.args[0], ast.Constant):
if len(node.args) == 1 and isinstance(
node.args[0], ast_Constant_Type
):
result = self._replacement(
node.args[0].value, node.func.ctx, None, transpose_lead=True
ast_String_value(node.args[0].value),
node.func.ctx,
None,
transpose_lead=True,
)
# handle clip as a method
if isinstance(node.func, ast.Attribute) and node.func.attr == "clip":
Expand Down Expand Up @@ -623,12 +653,17 @@ def visit_Call(self, node):
if isinstance(node.func, ast.Attribute) and node.func.attr == "max":
if isinstance(node.func.value, ast.Name):
if node.func.value.id == self.spacename:
if len(node.args) == 1 and isinstance(node.args[0], ast.Constant):
if len(node.args) == 1 and isinstance(
node.args[0], ast_Constant_Type
):
forward = self._replacement(
node.args[0].value, node.func.ctx, None
ast_String_value(node.args[0].value), node.func.ctx, None
)
backward = self._replacement(
node.args[0].value, node.func.ctx, None, transpose_lead=True
ast_String_value(node.args[0].value),
node.func.ctx,
None,
transpose_lead=True,
)
result = ast.Call(
func=ast.Name("max", ctx=ast.Load()),
Expand Down Expand Up @@ -673,7 +708,7 @@ def expression_for_numba(
digital_encodings=None,
prefer_name=None,
):
return ast.unparse(
return unparse(
RewriteForNumba(
spacename,
dim_slots,
Expand Down Expand Up @@ -703,7 +738,7 @@ def __call__(self, expr):
)
target = a.targets[0].id
new_tree = a.value
result = ast.unparse(new_tree)
result = unparse(new_tree)
self._cache[expr] = (target, result)
return target, result

Expand Down
9 changes: 3 additions & 6 deletions sharrow/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,9 +729,9 @@ def _flow_hash_push(x):
for k in extra_hash_data:
_flow_hash_push(k)

_flow_hash_push(f"{boundscheck=}")
_flow_hash_push(f"{error_model=}")
_flow_hash_push(f"{fastmath=}")
_flow_hash_push(f"boundscheck={boundscheck}")
_flow_hash_push(f"error_model={error_model}")
_flow_hash_push(f"fastmath={fastmath}")

self.flow_hash = base64.b32encode(flow_hash.digest()).decode()
self.flow_hash_audit = "]\n# [".join(flow_hash_audit)
Expand Down Expand Up @@ -792,9 +792,6 @@ def init_sub_funcs(
_dims = spacearrays._variables[k1].dims
except AttributeError:
_dims = spacearrays[k1].dims

print(f" YO {_dims=}")

dim_slots[k1] = [index_slots[z] for z in _dims]
try:
digital_encodings = spacearrays.digital_encoding.info()
Expand Down
10 changes: 9 additions & 1 deletion sharrow/relationships.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@

from .dataset import Dataset, construct

try:
from ast import unparse
except ImportError:
from astunparse import unparse as _unparse

unparse = lambda *args: _unparse(*args).strip("\n")


logger = logging.getLogger("sharrow")

well_known_names = {
Expand Down Expand Up @@ -1098,7 +1106,7 @@ def _arg_tokenizer(self, spacename, spacearray, exclude_dims=None):
parent_data, parent_name, exclude_dims=exclude_dims
)
try:
upside = ", ".join(ast.unparse(t) for t in upside_ast)
upside = ", ".join(unparse(t) for t in upside_ast)
except: # noqa: E722
for t in upside_ast:
print(f"t:{t}")
Expand Down
7 changes: 6 additions & 1 deletion sharrow/shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
import logging
import os
import pickle
from multiprocessing.shared_memory import ShareableList, SharedMemory

import dask
import dask.array as da
import numpy as np
import xarray as xr

try:
from multiprocessing.shared_memory import ShareableList, SharedMemory
except ImportError:
ShareableList, SharedMemory = None, None


__GLOBAL_MEMORY_ARRAYS = {}
__GLOBAL_MEMORY_LISTS = {}

Expand Down
5 changes: 2 additions & 3 deletions sharrow/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,10 @@ def from_quilt(cls, path, blockname=None):
else:
stopper = 1e99
n = 0
rowfile = lambda n: os.path.join(path, f"block.{n:03d}.rows") # noqa: E731
colfile = lambda n: os.path.join(path, f"block.{n:03d}.cols") # noqa: E731
rowfile = lambda n: os.path.join(path, f"block.{n:03d}.rows")
colfile = lambda n: os.path.join(path, f"block.{n:03d}.cols")
builder = None
look = True
print(f"{stopper=}")
while look and n <= stopper:
look = False
if os.path.exists(rowfile(n)):
Expand Down
7 changes: 6 additions & 1 deletion sharrow/tests/test_relationships.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys

import numpy as np
import pandas as pd
from numpy.random import SeedSequence, default_rng
from pytest import raises
from pytest import mark, raises

import sharrow
from sharrow import Dataset, DataTree, example_data
Expand Down Expand Up @@ -247,6 +249,9 @@ def _get_target(q):
q.put(skims_.SOV_TIME.sum())


@mark.skipif(
sys.version_info < (3, 8), reason="shared memory requires python3.8 or higher"
)
def test_shared_memory():

skims = example_data.get_skims()
Expand Down

0 comments on commit f0c0b92

Please sign in to comment.