diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index 79a20ba..3b9775e 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -12,7 +12,28 @@ on:
   workflow_dispatch:
 
 jobs:
+
+  fmt:
+    name: formatting quality
+    runs-on: ubuntu-latest
+    defaults:
+      run:
+        shell: bash -l {0}
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+      - name: Install Ruff
+        run: |
+          python -m pip install ruff
+      - name: Lint with Ruff
+        run: |
+          # code quality check, stop the build for any errors
+          ruff check . --show-fixes --exit-non-zero-on-fix
+
   test:
+    needs: fmt
     name: ${{ matrix.os }} py${{ matrix.python-version }}
     runs-on: ${{ matrix.os }}
     strategy:
@@ -23,13 +44,9 @@ jobs:
       run:
         shell: bash -l {0}
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        uses: conda-incubator/setup-miniconda@v2
+      - uses: actions/checkout@v4
+      - name: Install Python and Dependencies
+        uses: conda-incubator/setup-miniconda@v3
         with:
           miniforge-variant: Mambaforge
           miniforge-version: latest
@@ -48,10 +65,11 @@ jobs:
           conda list
       - name: Lint with Ruff
         run: |
+          # code quality check
           # stop the build if there are Python syntax errors or undefined names
-          ruff check . --select=E9,F63,F7,F82 --statistics
-          # exit-zero treats all errors as warnings.
-          ruff check . --exit-zero --statistics
+          ruff check . --select=E9,F63,F7,F82 --no-fix
+          # stop the build for any other configured Ruff linting errors
+          ruff check . --show-fixes --exit-non-zero-on-fix
       - name: Test with pytest
         run: |
           python -m pytest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 30a97c0..3e23f78 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,7 +1,7 @@
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.4.0
+  rev: v4.5.0
   hooks:
   - id: check-yaml
   - id: end-of-file-fixer
@@ -13,19 +13,14 @@ repos:
   hooks:
   - id: nbstripout
 
-- repo: https://github.com/charliermarsh/ruff-pre-commit
-  rev: v0.0.274
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version.
+  rev: v0.1.11
   hooks:
+  # Run the linter.
   - id: ruff
-    args: [--fix, --exit-non-zero-on-fix]
-
-- repo: https://github.com/pycqa/isort
-  rev: 5.12.0
-  hooks:
-  - id: isort
-    args: ["--profile", "black", "--filter-files"]
-
-- repo: https://github.com/psf/black
-  rev: 23.3.0
-  hooks:
-  - id: black
+    types_or: [ python, pyi, jupyter ]
+    args: [ --fix ]
+  # Run the formatter.
+  - id: ruff-format
+    types_or: [ python, pyi, jupyter ]
diff --git a/README.md b/README.md
index b2e4e8f..c72cb6c 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,18 @@
 # sharrow
 numba for ActivitySim-style spec files
 
+## Building a Wheel
+
+To build a wheel for sharrow, you need to have `build` installed. You can
+install it with `python -m pip install build`.
+
+Then run the builder:
+
+```shell
+python -m build .
+```
+
+
 ## Building the documentation
 
 Building the documentation for sharrow requires JupyterBook.
@@ -8,3 +20,26 @@ Building the documentation for sharrow requires JupyterBook.
 ```shell
 jupyterbook build docs
 ```
+
+## Developer Note
+
+This repository's continuous integration testing will use `ruff` to check code
+quality. There is a pre-commit hook that will run `ruff` on all staged files
+to ensure that they pass the quality checks.
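For reference, the hook applies the same Ruff checks as the CI workflow above. A minimal sketch of running them directly, assuming Ruff is installed from PyPI the same way the workflow installs it:

```shell
python -m pip install ruff
# the lint invocation used by the CI jobs above
ruff check . --show-fixes --exit-non-zero-on-fix
# roughly what the ruff-format hook applies
ruff format .
```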
To install and use this hook, +run the following commands: + +```shell +python -m pip install pre-commit # if needed +pre-commit install +``` + +Then, when you try to make a commit, your code will be checked locally to ensure +that your code passes the quality checks. If you want to run the checks manually, +you can do so with the following command: + +```shell +pre-commit run --all-files +``` + +If you don't use `pre-commit`, a service will run the checks for you when you +open a pull request, and make fixes to your code when possible. diff --git a/docs/walkthrough/encoding.ipynb b/docs/walkthrough/encoding.ipynb index 7f07159..b1359a0 100644 --- a/docs/walkthrough/encoding.ipynb +++ b/docs/walkthrough/encoding.ipynb @@ -23,7 +23,8 @@ "source": [ "# HIDDEN\n", "import warnings\n", - "warnings.filterwarnings(\"ignore\", category=DeprecationWarning) " + "\n", + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)" ] }, { @@ -38,9 +39,9 @@ "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", - "from io import StringIO\n", "\n", "import sharrow as sh\n", + "\n", "sh.__version__" ] }, @@ -57,6 +58,7 @@ "source": [ "# check versions\n", "import packaging\n", + "\n", "assert packaging.version.parse(xr.__version__) >= packaging.version.parse(\"0.20.2\")" ] }, @@ -146,7 +148,7 @@ "metadata": {}, "outputs": [], "source": [ - "from sharrow.digital_encoding import array_encode, array_decode" + "from sharrow.digital_encoding import array_decode, array_encode" ] }, { @@ -165,7 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "skims.DIST.values[:2,:3]" + "skims.DIST.values[:2, :3]" ] }, { @@ -210,7 +212,7 @@ "outputs": [], "source": [ "distance_encoded = array_encode(skims.DIST, scale=0.01, offset=0)\n", - "distance_encoded.values[:2,:3]" + "distance_encoded.values[:2, :3]" ] }, { @@ -227,10 +229,14 @@ "# TEST encoding\n", "assert distance_encoded.dtype == np.int16\n", "np.testing.assert_array_equal(\n", - " distance_encoded.values[:2,:3],\n", - " np.array([[12, 24, 44], [37, 14, 28]], dtype=np.int16)\n", + " distance_encoded.values[:2, :3],\n", + " np.array([[12, 24, 44], [37, 14, 28]], dtype=np.int16),\n", ")\n", - "assert distance_encoded.attrs['digital_encoding'] == {'scale': 0.01, 'offset': 0, 'missing_value': None}" + "assert distance_encoded.attrs[\"digital_encoding\"] == {\n", + " \"scale\": 0.01,\n", + " \"offset\": 0,\n", + " \"missing_value\": None,\n", + "}" ] }, { @@ -249,9 +255,7 @@ "metadata": {}, "outputs": [], "source": [ - "skims_encoded = skims.assign(\n", - " {'DIST': array_encode(skims.DIST, scale=0.01, offset=0)}\n", - ")" + "skims_encoded = skims.assign({\"DIST\": array_encode(skims.DIST, scale=0.01, offset=0)})" ] }, { @@ -271,7 +275,9 @@ "metadata": {}, "outputs": [], "source": [ - "skims_encoded = skims_encoded.digital_encoding.set(['DISTWALK', 'DISTBIKE'], scale=0.01, offset=0)" + "skims_encoded = skims_encoded.digital_encoding.set(\n", + " [\"DISTWALK\", \"DISTBIKE\"], scale=0.01, offset=0\n", + ")" ] }, { @@ -305,9 +311,9 @@ "source": [ "# TEST\n", "assert skims_encoded.digital_encoding.info() == {\n", - " 'DIST': {'scale': 0.01, 'offset': 0, 'missing_value': None},\n", - " 'DISTBIKE': {'scale': 0.01, 'offset': 0, 'missing_value': None},\n", - " 'DISTWALK': {'scale': 0.01, 'offset': 0, 'missing_value': None},\n", + " \"DIST\": {\"scale\": 0.01, \"offset\": 0, \"missing_value\": None},\n", + " \"DISTBIKE\": {\"scale\": 0.01, \"offset\": 0, \"missing_value\": None},\n", + " \"DISTWALK\": {\"scale\": 0.01, \"offset\": 0, 
\"missing_value\": None},\n", "}" ] }, @@ -330,16 +336,16 @@ "metadata": {}, "outputs": [], "source": [ - "pairs = pd.DataFrame({'orig': [0,0,0,1,1,1], 'dest': [0,1,2,0,1,2]})\n", + "pairs = pd.DataFrame({\"orig\": [0, 0, 0, 1, 1, 1], \"dest\": [0, 1, 2, 0, 1, 2]})\n", "tree = sh.DataTree(\n", - " base=pairs, \n", - " skims=skims.drop_dims('time_period'), \n", + " base=pairs,\n", + " skims=skims.drop_dims(\"time_period\"),\n", " relationships=(\n", " \"base.orig -> skims.otaz\",\n", " \"base.dest -> skims.dtaz\",\n", " ),\n", ")\n", - "flow = tree.setup_flow({'d1': 'DIST', 'd2': 'DIST**2'})\n", + "flow = tree.setup_flow({\"d1\": \"DIST\", \"d2\": \"DIST**2\"})\n", "arr = flow.load()\n", "arr" ] @@ -361,14 +367,14 @@ "outputs": [], "source": [ "tree_enc = sh.DataTree(\n", - " base=pairs, \n", - " skims=skims_encoded.drop_dims('time_period'), \n", + " base=pairs,\n", + " skims=skims_encoded.drop_dims(\"time_period\"),\n", " relationships=(\n", " \"base.orig -> skims.otaz\",\n", " \"base.dest -> skims.dtaz\",\n", " ),\n", ")\n", - "flow_enc = tree_enc.setup_flow({'d1': 'DIST', 'd2': 'DIST**2'})\n", + "flow_enc = tree_enc.setup_flow({\"d1\": \"DIST\", \"d2\": \"DIST**2\"})\n", "arr_enc = flow_enc.load()\n", "arr_enc" ] @@ -440,7 +446,7 @@ "metadata": {}, "outputs": [], "source": [ - "skims.WLK_LOC_WLK_FAR.values[:2,:3,:]" + "skims.WLK_LOC_WLK_FAR.values[:2, :3, :]" ] }, { @@ -460,7 +466,7 @@ "outputs": [], "source": [ "wlwfare_enc = array_encode(skims.WLK_LOC_WLK_FAR, bitwidth=8, by_dict=True)\n", - "wlwfare_enc.values[:2,:3,:]" + "wlwfare_enc.values[:2, :3, :]" ] }, { @@ -470,7 +476,7 @@ "metadata": {}, "outputs": [], "source": [ - "wlwfare_enc.attrs['digital_encoding']['dictionary']" + "wlwfare_enc.attrs[\"digital_encoding\"][\"dictionary\"]" ] }, { @@ -487,18 +493,18 @@ "# TEST encoding\n", "assert wlwfare_enc.dtype == np.uint8\n", "np.testing.assert_array_equal(\n", - " wlwfare_enc.values[:2,:3,:],\n", - " np.array([[[0, 0, 0, 0, 0],\n", - " [1, 2, 2, 1, 2],\n", - " [1, 2, 2, 1, 2]],\n", - "\n", - " [[1, 1, 2, 2, 1],\n", - " [0, 0, 0, 0, 0],\n", - " [1, 2, 2, 1, 2]]], dtype=np.uint8)\n", + " wlwfare_enc.values[:2, :3, :],\n", + " np.array(\n", + " [\n", + " [[0, 0, 0, 0, 0], [1, 2, 2, 1, 2], [1, 2, 2, 1, 2]],\n", + " [[1, 1, 2, 2, 1], [0, 0, 0, 0, 0], [1, 2, 2, 1, 2]],\n", + " ],\n", + " dtype=np.uint8,\n", + " ),\n", ")\n", "np.testing.assert_array_equal(\n", - " wlwfare_enc.attrs['digital_encoding']['dictionary'],\n", - " np.array([ 0., 152., 474., 626.], dtype=np.float32)\n", + " wlwfare_enc.attrs[\"digital_encoding\"][\"dictionary\"],\n", + " np.array([0.0, 152.0, 474.0, 626.0], dtype=np.float32),\n", ")" ] }, @@ -561,12 +567,14 @@ "outputs": [], "source": [ "skims1 = skims.digital_encoding.set(\n", - " ['WLK_LOC_WLK_FAR', \n", - " 'WLK_EXP_WLK_FAR', \n", - " 'WLK_HVY_WLK_FAR', \n", - " 'DRV_LOC_WLK_FAR',\n", - " 'DRV_HVY_WLK_FAR',\n", - " 'DRV_EXP_WLK_FAR'],\n", + " [\n", + " \"WLK_LOC_WLK_FAR\",\n", + " \"WLK_EXP_WLK_FAR\",\n", + " \"WLK_HVY_WLK_FAR\",\n", + " \"DRV_LOC_WLK_FAR\",\n", + " \"DRV_HVY_WLK_FAR\",\n", + " \"DRV_EXP_WLK_FAR\",\n", + " ],\n", " joint_dict=True,\n", ")" ] @@ -591,8 +599,7 @@ "outputs": [], "source": [ "skims1 = skims1.digital_encoding.set(\n", - " ['DISTBIKE', \n", - " 'DISTWALK'],\n", + " [\"DISTBIKE\", \"DISTWALK\"],\n", " joint_dict=\"jointWB\",\n", ")" ] @@ -638,9 +645,9 @@ "outputs": [], "source": [ "tree1 = sh.DataTree(\n", - " base=pairs, \n", - " skims=skims1, \n", - " rskims=skims1, \n", + " base=pairs,\n", + " skims=skims1,\n", + " 
rskims=skims1,\n", " relationships=(\n", " \"base.orig -> skims.otaz\",\n", " \"base.dest -> skims.dtaz\",\n", @@ -648,15 +655,18 @@ " \"base.dest -> rskims.otaz\",\n", " ),\n", ")\n", - "flow1 = tree1.setup_flow({\n", - " 'd1': 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]', \n", - " 'd2': 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]**2',\n", - " 'w1': 'skims.DISTWALK',\n", - " 'w2': 'skims.reverse(\"DISTWALK\")',\n", - " 'w3': 'rskims.DISTWALK',\n", - " 'x1': 'skims.DIST',\n", - " 'x2': 'skims.reverse(\"DIST\")',\n", - "}, hashing_level=2)\n", + "flow1 = tree1.setup_flow(\n", + " {\n", + " \"d1\": 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]',\n", + " \"d2\": 'skims[\"WLK_LOC_WLK_FAR\", \"AM\"]**2',\n", + " \"w1\": \"skims.DISTWALK\",\n", + " \"w2\": 'skims.reverse(\"DISTWALK\")',\n", + " \"w3\": \"rskims.DISTWALK\",\n", + " \"x1\": \"skims.DIST\",\n", + " \"x2\": 'skims.reverse(\"DIST\")',\n", + " },\n", + " hashing_level=2,\n", + ")\n", "arr1 = flow1.load_dataframe()\n", "arr1" ] @@ -669,13 +679,72 @@ "outputs": [], "source": [ "# TEST\n", - "assert (arr1 == np.array([[ 0.00000e+00, 0.00000e+00, 1.20000e-01, 1.20000e-01, 1.20000e-01, 1.20000e-01, 1.20000e-01],\n", - " [ 4.74000e+02, 2.24676e+05, 2.40000e-01, 3.70000e-01, 3.70000e-01, 2.40000e-01, 3.70000e-01],\n", - " [ 4.74000e+02, 2.24676e+05, 4.40000e-01, 5.70000e-01, 5.70000e-01, 4.40000e-01, 5.70000e-01],\n", - " [ 1.52000e+02, 2.31040e+04, 3.70000e-01, 2.40000e-01, 2.40000e-01, 3.70000e-01, 2.40000e-01],\n", - " [ 0.00000e+00, 0.00000e+00, 1.40000e-01, 1.40000e-01, 1.40000e-01, 1.40000e-01, 1.40000e-01],\n", - " [ 4.74000e+02, 2.24676e+05, 2.80000e-01, 2.80000e-01, 2.80000e-01, 2.80000e-01, 2.80000e-01]],\n", - " dtype=np.float32)).all().all()" + "assert (\n", + " (\n", + " arr1\n", + " == np.array(\n", + " [\n", + " [\n", + " 0.00000e00,\n", + " 0.00000e00,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " 1.20000e-01,\n", + " ],\n", + " [\n", + " 4.74000e02,\n", + " 2.24676e05,\n", + " 2.40000e-01,\n", + " 3.70000e-01,\n", + " 3.70000e-01,\n", + " 2.40000e-01,\n", + " 3.70000e-01,\n", + " ],\n", + " [\n", + " 4.74000e02,\n", + " 2.24676e05,\n", + " 4.40000e-01,\n", + " 5.70000e-01,\n", + " 5.70000e-01,\n", + " 4.40000e-01,\n", + " 5.70000e-01,\n", + " ],\n", + " [\n", + " 1.52000e02,\n", + " 2.31040e04,\n", + " 3.70000e-01,\n", + " 2.40000e-01,\n", + " 2.40000e-01,\n", + " 3.70000e-01,\n", + " 2.40000e-01,\n", + " ],\n", + " [\n", + " 0.00000e00,\n", + " 0.00000e00,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " 1.40000e-01,\n", + " ],\n", + " [\n", + " 4.74000e02,\n", + " 2.24676e05,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " 2.80000e-01,\n", + " ],\n", + " ],\n", + " dtype=np.float32,\n", + " )\n", + " )\n", + " .all()\n", + " .all()\n", + ")" ] }, { @@ -686,11 +755,13 @@ "outputs": [], "source": [ "# TEST\n", - "assert skims1.digital_encoding.baggage(['WLK_LOC_WLK_FAR']) == {'joined_0_offsets'}\n", - "assert (skims1.iat(\n", - " otaz=[0,1,2], dtaz=[0,0,0], time_period=[1,1,1],\n", - " _name='WLK_LOC_WLK_FAR'\n", - ").to_series() == [0,152,474]).all()" + "assert skims1.digital_encoding.baggage([\"WLK_LOC_WLK_FAR\"]) == {\"joined_0_offsets\"}\n", + "assert (\n", + " skims1.iat(\n", + " otaz=[0, 1, 2], dtaz=[0, 0, 0], time_period=[1, 1, 1], _name=\"WLK_LOC_WLK_FAR\"\n", + " ).to_series()\n", + " == [0, 152, 474]\n", + ").all()" ] }, { @@ -720,8 +791,10 @@ "outputs": [], "source": [ "hh = 
sh.example_data.get_households()\n", - "hh[\"income_grp\"] = pd.cut(hh.income, bins=[-np.inf,30000,60000,np.inf], labels=['Low', \"Mid\", \"High\"])\n", - "hh = hh[[\"income\",\"income_grp\"]]\n", + "hh[\"income_grp\"] = pd.cut(\n", + " hh.income, bins=[-np.inf, 30000, 60000, np.inf], labels=[\"Low\", \"Mid\", \"High\"]\n", + ")\n", + "hh = hh[[\"income\", \"income_grp\"]]\n", "hh.head()" ] }, @@ -754,7 +827,7 @@ }, "outputs": [], "source": [ - "hh_dataset = sh.dataset.construct(hh[[\"income\",\"income_grp\"]])\n", + "hh_dataset = sh.dataset.construct(hh[[\"income\", \"income_grp\"]])\n", "hh_dataset" ] }, @@ -793,9 +866,12 @@ "source": [ "# TESTING\n", "assert hh_dataset[\"income_grp\"].dtype == \"int8\"\n", - "assert hh_dataset[\"income_grp\"].digital_encoding.keys() == {'dictionary', 'ordered'}\n", - "assert all(hh_dataset[\"income_grp\"].digital_encoding['dictionary'] == np.array(['Low', 'Mid', 'High'], dtype='= packaging.version.parse(\"0.20.2\")" ] }, @@ -84,7 +87,7 @@ "source": [ "# TEST households content\n", "assert len(households) == 5000\n", - "assert \"income\" in households \n", + "assert \"income\" in households\n", "assert households.index.name == \"HHID\"" ] }, @@ -112,7 +115,7 @@ "source": [ "assert len(persons) == 8212\n", "assert \"household_id\" in persons\n", - "assert persons.index.name == 'PERID'" + "assert persons.index.name == \"PERID\"" ] }, { @@ -178,13 +181,17 @@ "source": [ "def random_tours(n_tours=100_000, seed=42):\n", " rng = np.random.default_rng(seed)\n", - " n_zones = skims.dims['dtaz']\n", - " return pd.DataFrame({\n", - " 'PERID': rng.choice(persons.index, size=n_tours),\n", - " 'dest_taz_idx': rng.choice(n_zones, size=n_tours),\n", - " 'out_time_period': rng.choice(skims.time_period, size=n_tours),\n", - " 'in_time_period': rng.choice(skims.time_period, size=n_tours),\n", - " }).rename_axis(\"TOURIDX\")\n", + " n_zones = skims.dims[\"dtaz\"]\n", + " return pd.DataFrame(\n", + " {\n", + " \"PERID\": rng.choice(persons.index, size=n_tours),\n", + " \"dest_taz_idx\": rng.choice(n_zones, size=n_tours),\n", + " \"out_time_period\": rng.choice(skims.time_period, size=n_tours),\n", + " \"in_time_period\": rng.choice(skims.time_period, size=n_tours),\n", + " }\n", + " ).rename_axis(\"TOURIDX\")\n", + "\n", + "\n", "tours = random_tours()\n", "tours.head()" ] @@ -269,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "spec = pd.read_csv(StringIO(mini_spec), index_col='Label')\n", + "spec = pd.read_csv(StringIO(mini_spec), index_col=\"Label\")\n", "spec" ] }, @@ -286,7 +293,7 @@ "source": [ "# TEST check spec\n", "assert spec.index.name == \"Label\"\n", - "assert all(spec.columns == ['Expression', 'DRIVE', 'WALK', 'TRANSIT'])" + "assert all(spec.columns == [\"Expression\", \"DRIVE\", \"WALK\", \"TRANSIT\"])" ] }, { @@ -309,7 +316,7 @@ "metadata": {}, "outputs": [], "source": [ - "income_breakpoints = nb.typed.Dict.empty(nb.types.int32,nb.types.int32)\n", + "income_breakpoints = nb.typed.Dict.empty(nb.types.int32, nb.types.int32)\n", "income_breakpoints[0] = 15000\n", "income_breakpoints[1] = 30000\n", "income_breakpoints[2] = 60000\n", @@ -331,12 +338,12 @@ " \"tour.in_time_period @ dot_skims.time_period\",\n", " ),\n", " extra_vars={\n", - " 'shortwait': 3,\n", - " 'one': 1,\n", + " \"shortwait\": 3,\n", + " \"one\": 1,\n", " },\n", " aux_vars={\n", - " 'short_i_wait_mult': 0.75,\n", - " 'income_breakpoints': income_breakpoints,\n", + " \"short_i_wait_mult\": 0.75,\n", + " \"income_breakpoints\": income_breakpoints,\n", " },\n", ")" ] @@ -410,9 +417,9 @@ 
"outputs": [], "source": [ "# TEST\n", - "from pytest import approx\n", - "assert flow.tree.aux_vars['short_i_wait_mult'] == 0.75\n", - "assert flow.tree.aux_vars['income_breakpoints'][2] == 60000" + "\n", + "assert flow.tree.aux_vars[\"short_i_wait_mult\"] == 0.75\n", + "assert flow.tree.aux_vars[\"income_breakpoints\"][2] == 60000" ] }, { @@ -439,16 +446,21 @@ "# TEST utility data\n", "assert flow.check_cache_misses(fresh=False)\n", "actual = flow.load()\n", - "expected = np.array([[ 9.4 , 16.9572 , 4.5 , 0. , 1. ],\n", - " [ 9.32 , 14.3628 , 4.5 , 1. , 1. ],\n", - " [ 7.62 , 11.0129 , 4.5 , 1. , 1. ],\n", - " [ 4.25 , 7.6692 , 2.50065 , 0. , 1. ],\n", - " [ 6.16 , 8.2186 , 3.387825, 0. , 1. ],\n", - " [ 4.86 , 4.9288 , 4.5 , 0. , 1. ],\n", - " [ 1.07 , 0. , 0. , 0. , 1. ],\n", - " [ 8.52 , 11.615499, 3.260325, 0. , 1. ],\n", - " [ 11.74 , 16.2798 , 3.440325, 0. , 1. ],\n", - " [ 10.48 , 13.3974 , 3.942825, 0. , 1. ]], dtype=np.float32)\n", + "expected = np.array(\n", + " [\n", + " [9.4, 16.9572, 4.5, 0.0, 1.0],\n", + " [9.32, 14.3628, 4.5, 1.0, 1.0],\n", + " [7.62, 11.0129, 4.5, 1.0, 1.0],\n", + " [4.25, 7.6692, 2.50065, 0.0, 1.0],\n", + " [6.16, 8.2186, 3.387825, 0.0, 1.0],\n", + " [4.86, 4.9288, 4.5, 0.0, 1.0],\n", + " [1.07, 0.0, 0.0, 0.0, 1.0],\n", + " [8.52, 11.615499, 3.260325, 0.0, 1.0],\n", + " [11.74, 16.2798, 3.440325, 0.0, 1.0],\n", + " [10.48, 13.3974, 3.942825, 0.0, 1.0],\n", + " ],\n", + " dtype=np.float32,\n", + ")\n", "\n", "np.testing.assert_array_almost_equal(actual[:5], expected[:5])\n", "np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])\n", @@ -483,8 +495,11 @@ "# TEST compile flags\n", "flow.load(compile_watch=False)\n", "import pytest\n", + "\n", "with pytest.raises(AttributeError):\n", - " flow.compiled_recently # attribute does not exist if compile_watch flag is off" + " compiled_recently = (\n", + " flow.compiled_recently\n", + " ) # attribute does not exist if compile_watch flag is off" ] }, { @@ -542,8 +557,9 @@ "source": [ "# TEST\n", "from pytest import approx\n", - "assert tree_2.aux_vars['short_i_wait_mult'] == 0.75\n", - "assert tree_2.aux_vars['income_breakpoints'][2] == approx(60000)" + "\n", + "assert tree_2.aux_vars[\"short_i_wait_mult\"] == 0.75\n", + "assert tree_2.aux_vars[\"income_breakpoints\"][2] == approx(60000)" ] }, { @@ -565,18 +581,23 @@ "source": [ "# TEST that aux_vars also work with arrays\n", "tree_a = tree_2.replace_datasets(tour=tours)\n", - "tree_a.aux_vars['income_breakpoints'] = np.asarray([1,2,60000])\n", + "tree_a.aux_vars[\"income_breakpoints\"] = np.asarray([1, 2, 60000])\n", "actual = flow.load(tree_a)\n", - "expected = np.array([[ 9.4 , 16.9572 , 4.5 , 0. , 1. ],\n", - " [ 9.32 , 14.3628 , 4.5 , 1. , 1. ],\n", - " [ 7.62 , 11.0129 , 4.5 , 1. , 1. ],\n", - " [ 4.25 , 7.6692 , 2.50065 , 0. , 1. ],\n", - " [ 6.16 , 8.2186 , 3.387825, 0. , 1. ],\n", - " [ 4.86 , 4.9288 , 4.5 , 0. , 1. ],\n", - " [ 1.07 , 0. , 0. , 0. , 1. ],\n", - " [ 8.52 , 11.615499, 3.260325, 0. , 1. ],\n", - " [ 11.74 , 16.2798 , 3.440325, 0. , 1. ],\n", - " [ 10.48 , 13.3974 , 3.942825, 0. , 1. 
]], dtype=np.float32)\n", + "expected = np.array(\n", + " [\n", + " [9.4, 16.9572, 4.5, 0.0, 1.0],\n", + " [9.32, 14.3628, 4.5, 1.0, 1.0],\n", + " [7.62, 11.0129, 4.5, 1.0, 1.0],\n", + " [4.25, 7.6692, 2.50065, 0.0, 1.0],\n", + " [6.16, 8.2186, 3.387825, 0.0, 1.0],\n", + " [4.86, 4.9288, 4.5, 0.0, 1.0],\n", + " [1.07, 0.0, 0.0, 0.0, 1.0],\n", + " [8.52, 11.615499, 3.260325, 0.0, 1.0],\n", + " [11.74, 16.2798, 3.440325, 0.0, 1.0],\n", + " [10.48, 13.3974, 3.942825, 0.0, 1.0],\n", + " ],\n", + " dtype=np.float32,\n", + ")\n", "\n", "np.testing.assert_array_almost_equal(actual[:5], expected[:5])\n", "np.testing.assert_array_almost_equal(actual[-5:], expected[-5:])\n", @@ -617,15 +638,20 @@ "# TEST df\n", "assert len(df) == len(tours)\n", "pd.testing.assert_index_equal(\n", - " df.columns, \n", - " pd.Index(['Drive Time', 'Transit IVT', 'Transit Wait Time', 'Income', 'Constant']),\n", + " df.columns,\n", + " pd.Index([\"Drive Time\", \"Transit IVT\", \"Transit Wait Time\", \"Income\", \"Constant\"]),\n", ")\n", - "expected_df_head = pd.read_csv(StringIO(''',Drive Time,Transit IVT,Transit Wait Time,Income,Constant\n", + "expected_df_head = pd.read_csv(\n", + " StringIO(\n", + " \"\"\",Drive Time,Transit IVT,Transit Wait Time,Income,Constant\n", "0,9.4,16.9572,4.5,0.0,1.0\n", "1,9.32,14.3628,4.5,1.0,1.0\n", "2,7.62,11.0129,4.5,1.0,1.0\n", "3,4.25,7.6692,2.50065,0.0,1.0\n", - "4,6.16,8.2186,3.387825,0.0,1.0'''), index_col=0).astype(np.float32)\n", + "4,6.16,8.2186,3.387825,0.0,1.0\"\"\"\n", + " ),\n", + " index_col=0,\n", + ").astype(np.float32)\n", "pd.testing.assert_frame_equal(df.head(), expected_df_head)" ] }, @@ -651,7 +677,7 @@ "outputs": [], "source": [ "x = flow.load()\n", - "b = spec.iloc[:,1:].fillna(0).astype(np.float32).values\n", + "b = spec.iloc[:, 1:].fillna(0).astype(np.float32).values\n", "np.dot(x, b)" ] }, @@ -672,7 +698,17 @@ "metadata": {}, "outputs": [], "source": [ - "%time u = flow.dot(b)\n", + "%time flow.dot(b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5776822fb0889df", + "metadata": {}, + "outputs": [], + "source": [ + "u = flow.dot(b)\n", "u" ] }, @@ -747,8 +783,7 @@ "outputs": [], "source": [ "B = xr.DataArray(\n", - " spec.iloc[:,1:].fillna(0).astype(np.float32), \n", - " dims=('expressions','modes')\n", + " spec.iloc[:, 1:].fillna(0).astype(np.float32), dims=(\"expressions\", \"modes\")\n", ")\n", "flow.dot_dataarray(B, source=tree_2)" ] @@ -788,6 +823,16 @@ "was computed for each chosen alternative. " ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "d54d71021951470b", + "metadata": {}, + "outputs": [], + "source": [ + "choices, choice_probs = flow.logit_draws(b, draws)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -809,6 +854,16 @@ "milliseconds more time than just computing the utilities." 
] }, + { + "cell_type": "code", + "execution_count": null, + "id": "eec9ebd14ff646eb", + "metadata": {}, + "outputs": [], + "source": [ + "choices2, choice_probs2 = flow.logit_draws(b, draws, source=tree_2)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -868,9 +923,9 @@ "source": [ "# TEST mnl choices\n", "uz = np.exp(flow.dot(b))\n", - "uz = uz / uz.sum(1)[:,None]\n", + "uz = uz / uz.sum(1)[:, None]\n", "np.testing.assert_array_almost_equal(\n", - " uz[range(uz.shape[0]),choices.ravel()],\n", + " uz[range(uz.shape[0]), choices.ravel()],\n", " choice_probs.ravel(),\n", ")" ] @@ -926,12 +981,12 @@ "\"\"\"\n", "\n", "import yaml\n", + "\n", "from sharrow.nested_logit import construct_nesting_tree\n", "\n", - "nesting_settings = yaml.safe_load(nesting_settings)['NESTS']\n", + "nesting_settings = yaml.safe_load(nesting_settings)[\"NESTS\"]\n", "nest_tree = construct_nesting_tree(\n", - " alternatives=spec.columns[1:],\n", - " nesting_settings=nesting_settings\n", + " alternatives=spec.columns[1:], nesting_settings=nesting_settings\n", ")" ] }, @@ -965,7 +1020,9 @@ "metadata": {}, "outputs": [], "source": [ - "nesting = nest_tree.as_arrays(trim=True, parameter_dict={'coef_nest_motor': 0.5, 'coef_nest_root': 1.0})" + "nesting = nest_tree.as_arrays(\n", + " trim=True, parameter_dict={\"coef_nest_motor\": 0.5, \"coef_nest_root\": 1.0}\n", + ")" ] }, { @@ -1023,8 +1080,11 @@ "source": [ "# TEST devolve NL to MNL\n", "choices_nl_1, choice_probs_nl_1 = flow.logit_draws(\n", - " b, draws, \n", - " nesting=nest_tree.as_arrays(trim=True, parameter_dict={'coef_nest_motor': 1.0, 'coef_nest_root': 1.0}),\n", + " b,\n", + " draws,\n", + " nesting=nest_tree.as_arrays(\n", + " trim=True, parameter_dict={\"coef_nest_motor\": 1.0, \"coef_nest_root\": 1.0}\n", + " ),\n", ")\n", "assert (choices_nl_1 == choices).all()\n", "assert choice_probs == approx(choice_probs_nl_1)" @@ -1055,23 +1115,28 @@ "metadata": {}, "outputs": [], "source": [ - "# TEST \n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1)\n", + "# TEST\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=1\n", + ")\n", "assert _ch is None\n", "assert _pr is None\n", "assert _pc is None\n", "assert _ls.size == 100000\n", "np.testing.assert_array_almost_equal(\n", - " _ls[:5],\n", - " [ 0.532791, 0.490935, 0.557529, 0.556371, 0.54812 ]\n", + " _ls[:5], [0.532791, 0.490935, 0.557529, 0.556371, 0.54812]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _ls[-5:],\n", - " [ 0.452682, 0.465422, 0.554312, 0.525064, 0.515226 ]\n", + " _ls[-5:], [0.452682, 0.465422, 0.554312, 0.525064, 0.515226]\n", ")\n", "\n", "_ch, _pr, _pc, _ls = flow.logit_draws(\n", - " b, draws, source=tree_2, nesting=nesting, logsums=1, as_dataarray=True,\n", + " b,\n", + " draws,\n", + " source=tree_2,\n", + " nesting=nesting,\n", + " logsums=1,\n", + " as_dataarray=True,\n", ")\n", "assert _ch is None\n", "assert _pr is None\n", @@ -1091,7 +1156,9 @@ "# TEST masking\n", "masker = np.zeros(draws.shape, dtype=np.int8)\n", "masker[::2] = 1\n", - "_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=1, mask=masker)\n", + "_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=1, mask=masker\n", + ")\n", "\n", "assert _ls_m == approx(np.where(masker, _ls, 0))\n", "assert (_ch_m, _pr_m, _pc_m) == (None, None, None)" @@ -1126,37 +1193,31 @@ "metadata": {}, "outputs": [], "source": [ - 
"# TEST \n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=2)\n", + "# TEST\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=2\n", + ")\n", "assert _ch.size == 100000\n", "assert _pr.size == 100000\n", "assert _pc is None\n", "assert _ls.size == 100000\n", + "np.testing.assert_array_almost_equal(_ch[:5], [1, 2, 1, 1, 1])\n", + "np.testing.assert_array_almost_equal(_ch[-5:], [0, 1, 0, 1, 0])\n", "np.testing.assert_array_almost_equal(\n", - " _ch[:5],\n", - " [ 1, 2, 1, 1, 1 ]\n", - ")\n", - "np.testing.assert_array_almost_equal(\n", - " _ch[-5:],\n", - " [ 0, 1, 0, 1, 0 ]\n", + " _pr[:5], [0.393454, 0.16956, 0.38384, 0.384285, 0.387469]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _pr[:5],\n", - " [ 0.393454, 0.16956 , 0.38384 , 0.384285, 0.387469 ]\n", + " _pr[-5:], [0.503606, 0.420874, 0.478898, 0.396506, 0.468742]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _pr[-5:],\n", - " [ 0.503606, 0.420874, 0.478898, 0.396506, 0.468742 ]\n", + " _ls[:5], [0.532791, 0.490935, 0.557529, 0.556371, 0.54812]\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " _ls[:5],\n", - " [ 0.532791, 0.490935, 0.557529, 0.556371, 0.54812 ]\n", + " _ls[-5:], [0.452682, 0.465422, 0.554312, 0.525064, 0.515226]\n", ")\n", - "np.testing.assert_array_almost_equal(\n", - " _ls[-5:],\n", - " [ 0.452682, 0.465422, 0.554312, 0.525064, 0.515226 ]\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True\n", ")\n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True)\n", "assert _ch.size == 100000\n", "assert _ch.dims == (\"TOURIDX\",)\n", "assert _ch.shape == (100000,)\n", @@ -1177,23 +1238,33 @@ "source": [ "# TEST\n", "draws_many = np.random.default_rng(42).random(size=(tree.shape[0], 5))\n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True)\n", - "assert _ch.dims == ('TOURIDX', 'DRAW')\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True\n", + ")\n", + "assert _ch.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _ch.shape == (100000, 5)\n", - "assert _pr.dims == ('TOURIDX', 'DRAW')\n", + "assert _pr.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _pr.shape == (100000, 5)\n", - "assert _ls.dims == ('TOURIDX', )\n", - "assert _ls.shape == (100000, )\n", + "assert _ls.dims == (\"TOURIDX\",)\n", + "assert _ls.shape == (100000,)\n", "assert _pc is None\n", "\n", - "_ch, _pr, _pc, _ls = flow.logit_draws(b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True, pick_counted=True)\n", - "assert _ch.dims == ('TOURIDX', 'DRAW')\n", + "_ch, _pr, _pc, _ls = flow.logit_draws(\n", + " b,\n", + " draws_many,\n", + " source=tree_2,\n", + " nesting=nesting,\n", + " logsums=2,\n", + " as_dataarray=True,\n", + " pick_counted=True,\n", + ")\n", + "assert _ch.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _ch.shape == (100000, 5)\n", - "assert _pr.dims == ('TOURIDX', 'DRAW')\n", + "assert _pr.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _pr.shape == (100000, 5)\n", - "assert _ls.dims == ('TOURIDX', )\n", - "assert _ls.shape == (100000, )\n", - "assert _pc.dims == ('TOURIDX', 'DRAW')\n", + "assert _ls.dims == (\"TOURIDX\",)\n", + "assert _ls.shape == (100000,)\n", + "assert _pc.dims == (\"TOURIDX\", \"DRAW\")\n", "assert _pc.shape == (100000, 
5)" ] }, @@ -1209,7 +1280,14 @@ "masker[::3] = 1\n", "\n", "_ch_m, _pr_m, _pc_m, _ls_m = flow.logit_draws(\n", - " b, draws_many, source=tree_2, nesting=nesting, logsums=2, as_dataarray=True, mask=masker, pick_counted=True\n", + " b,\n", + " draws_many,\n", + " source=tree_2,\n", + " nesting=nesting,\n", + " logsums=2,\n", + " as_dataarray=True,\n", + " mask=masker,\n", + " pick_counted=True,\n", ")\n", "\n", "assert (_ch_m.values == (np.where(np.expand_dims(masker, -1), _ch, -1))).all()\n", @@ -1241,8 +1319,10 @@ "metadata": {}, "outputs": [], "source": [ - "tour_by_dest = tree.subspaces['tour']\n", - "tour_by_dest = tour_by_dest.assign_coords({'CAND_DEST': xr.DataArray(np.arange(25), dims='CAND_DEST')})\n", + "tour_by_dest = tree.subspaces[\"tour\"]\n", + "tour_by_dest = tour_by_dest.assign_coords(\n", + " {\"CAND_DEST\": xr.DataArray(np.arange(25), dims=\"CAND_DEST\")}\n", + ")\n", "tour_by_dest" ] }, @@ -1278,14 +1358,14 @@ " \"tour.in_time_period @ dot_skims.time_period\",\n", " ),\n", " extra_vars={\n", - " 'shortwait': 3,\n", - " 'one': 1,\n", + " \"shortwait\": 3,\n", + " \"one\": 1,\n", " },\n", " aux_vars={\n", - " 'short_i_wait_mult': 0.75,\n", - " 'income_breakpoints': income_breakpoints,\n", + " \"short_i_wait_mult\": 0.75,\n", + " \"income_breakpoints\": income_breakpoints,\n", " },\n", - " dim_order=('TOURIDX', 'CAND_DEST')\n", + " dim_order=(\"TOURIDX\", \"CAND_DEST\"),\n", ")\n", "wide_flow = wide_tree.setup_flow(spec.Expression)" ] @@ -1297,7 +1377,7 @@ "metadata": {}, "outputs": [], "source": [ - "%time wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch=\"simple\")[-1]" + "wide_logsums = wide_flow.logit_draws(b, logsums=1, compile_watch=\"simple\")[-1]" ] }, { @@ -1320,20 +1400,30 @@ "source": [ "# TEST\n", "np.testing.assert_array_almost_equal(\n", - " wide_logsums[:5,:5],\n", - " np.array([[ 0.759222, 0.75862 , 0.744936, 0.758251, 0.737007],\n", - " [ 0.671698, 0.671504, 0.663015, 0.661482, 0.667133],\n", - " [ 0.670188, 0.678498, 0.687647, 0.691152, 0.715783],\n", - " [ 0.760743, 0.769123, 0.763733, 0.784487, 0.802356],\n", - " [ 0.73474 , 0.743051, 0.751439, 0.754731, 0.778121]], dtype=np.float32)\n", + " wide_logsums[:5, :5],\n", + " np.array(\n", + " [\n", + " [0.759222, 0.75862, 0.744936, 0.758251, 0.737007],\n", + " [0.671698, 0.671504, 0.663015, 0.661482, 0.667133],\n", + " [0.670188, 0.678498, 0.687647, 0.691152, 0.715783],\n", + " [0.760743, 0.769123, 0.763733, 0.784487, 0.802356],\n", + " [0.73474, 0.743051, 0.751439, 0.754731, 0.778121],\n", + " ],\n", + " dtype=np.float32,\n", + " ),\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " wide_logsums[-5:,-5:],\n", - " np.array([[ 0.719523, 0.755152, 0.739368, 0.762664, 0.764388],\n", - " [ 0.740303, 0.678783, 0.649964, 0.694407, 0.681555],\n", - " [ 0.758865, 0.663663, 0.637266, 0.673351, 0.65875 ],\n", - " [ 0.765125, 0.706478, 0.676878, 0.717814, 0.713912],\n", - " [ 0.73348 , 0.683626, 0.647698, 0.69146 , 0.673006]], dtype=np.float32)\n", + " wide_logsums[-5:, -5:],\n", + " np.array(\n", + " [\n", + " [0.719523, 0.755152, 0.739368, 0.762664, 0.764388],\n", + " [0.740303, 0.678783, 0.649964, 0.694407, 0.681555],\n", + " [0.758865, 0.663663, 0.637266, 0.673351, 0.65875],\n", + " [0.765125, 0.706478, 0.676878, 0.717814, 0.713912],\n", + " [0.73348, 0.683626, 0.647698, 0.69146, 0.673006],\n", + " ],\n", + " dtype=np.float32,\n", + " ),\n", ")" ] }, @@ -1346,8 +1436,8 @@ "source": [ "# TEST\n", "np.testing.assert_array_almost_equal(\n", - " wide_logsums[np.arange(len(tours)), 
tours['dest_taz_idx'].to_numpy()],\n", - " flow.logit_draws(b, logsums=1)[-1]\n", + " wide_logsums[np.arange(len(tours)), tours[\"dest_taz_idx\"].to_numpy()],\n", + " flow.logit_draws(b, logsums=1)[-1],\n", ")" ] }, @@ -1359,7 +1449,9 @@ "outputs": [], "source": [ "# TEST\n", - "wide_logsums_ = wide_flow.logit_draws(b, logsums=1, compile_watch=True, as_dataarray=True)[-1]\n", + "wide_logsums_ = wide_flow.logit_draws(\n", + " b, logsums=1, compile_watch=True, as_dataarray=True\n", + ")[-1]\n", "assert wide_logsums_.dims == (\"TOURIDX\", \"CAND_DEST\")\n", "assert wide_logsums_.shape == (100000, 25)" ] @@ -1392,7 +1484,9 @@ "source": [ "# TEST\n", "wide_draws = np.random.default_rng(42).random(size=wide_tree.shape + (2,))\n", - "wide_logsums_plus = wide_flow.logit_draws(b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws)\n", + "wide_logsums_plus = wide_flow.logit_draws(\n", + " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws\n", + ")\n", "assert wide_logsums_plus[0].dims == (\"TOURIDX\", \"CAND_DEST\", \"DRAW\")\n", "assert wide_logsums_plus[0].shape == (100000, 25, 2)\n", "assert wide_logsums_plus[3].dims == (\"TOURIDX\", \"CAND_DEST\")\n", @@ -1418,8 +1512,12 @@ "assert wide_logsums_mask[3].dims == (\"TOURIDX\", \"CAND_DEST\")\n", "assert wide_logsums_mask[3].shape == (100000, 25)\n", "\n", - "assert (wide_logsums_plus[0].where(np.expand_dims(mask, -1), -1) == wide_logsums_mask[0]).all()\n", - "assert (wide_logsums_plus[1].where(np.expand_dims(mask, -1), 0) == wide_logsums_mask[1]).all()\n", + "assert (\n", + " wide_logsums_plus[0].where(np.expand_dims(mask, -1), -1) == wide_logsums_mask[0]\n", + ").all()\n", + "assert (\n", + " wide_logsums_plus[1].where(np.expand_dims(mask, -1), 0) == wide_logsums_mask[1]\n", + ").all()\n", "assert (wide_logsums_plus[3].where(mask, 0) == wide_logsums_mask[3]).all()" ] }, @@ -1431,17 +1529,30 @@ "outputs": [], "source": [ "# TEST masking performance\n", - "import timeit, warnings\n", + "import timeit\n", + "import warnings\n", + "\n", "with warnings.catch_warnings():\n", " warnings.simplefilter(\"error\")\n", - " masked_time = timeit.timeit(lambda: wide_flow.logit_draws(\n", - " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws, mask=mask\n", - " ), number=1)\n", - " raw_time = timeit.timeit(lambda: wide_flow.logit_draws(\n", - " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws\n", - " ), number=1)\n", + " masked_time = timeit.timeit(\n", + " lambda: wide_flow.logit_draws(\n", + " b,\n", + " logsums=2,\n", + " compile_watch=True,\n", + " as_dataarray=True,\n", + " draws=wide_draws,\n", + " mask=mask,\n", + " ),\n", + " number=1,\n", + " )\n", + " raw_time = timeit.timeit(\n", + " lambda: wide_flow.logit_draws(\n", + " b, logsums=2, compile_watch=True, as_dataarray=True, draws=wide_draws\n", + " ),\n", + " number=1,\n", + " )\n", "assert masked_time * 2 < raw_time # generous buffer, should be nearly 7 times faster\n", - "assert len(wide_flow.cache_misses['_imnl_plus1d']) == 3" + "assert len(wide_flow.cache_misses[\"_imnl_plus1d\"]) == 3" ] } ], diff --git a/docs/walkthrough/sparse.ipynb b/docs/walkthrough/sparse.ipynb index 73565bf..dc4415d 100644 --- a/docs/walkthrough/sparse.ipynb +++ b/docs/walkthrough/sparse.ipynb @@ -17,7 +17,7 @@ "source": [ "import numpy as np\n", "import pandas as pd\n", - "import xarray as xr\n", + "\n", "import sharrow as sh" ] }, @@ -106,10 +106,10 @@ "outputs": [], "source": [ "skims.redirection.set(\n", - " maz_taz, \n", - " map_to='otaz', \n", + 
" maz_taz,\n", + " map_to=\"otaz\",\n", " name=\"omaz\",\n", - " map_also={'dtaz': \"dmaz\"}, \n", + " map_also={\"dtaz\": \"dmaz\"},\n", ")" ] }, @@ -141,9 +141,9 @@ "outputs": [], "source": [ "skims.redirection.sparse_blender(\n", - " 'DISTWALK', \n", - " maz_to_maz_walk.OMAZ, \n", - " maz_to_maz_walk.DMAZ, \n", + " \"DISTWALK\",\n", + " maz_to_maz_walk.OMAZ,\n", + " maz_to_maz_walk.DMAZ,\n", " maz_to_maz_walk.DISTWALK,\n", " max_blend_distance=1.0,\n", " index=maz_taz.index,\n", @@ -170,10 +170,12 @@ "metadata": {}, "outputs": [], "source": [ - "trips = pd.DataFrame({\n", - " 'orig_maz': [100, 100, 100, 200, 200],\n", - " 'dest_maz': [100, 101, 103, 201, 202],\n", - "})\n", + "trips = pd.DataFrame(\n", + " {\n", + " \"orig_maz\": [100, 100, 100, 200, 200],\n", + " \"dest_maz\": [100, 101, 103, 201, 202],\n", + " }\n", + ")\n", "trips" ] }, @@ -199,7 +201,7 @@ " relationships=(\n", " \"base.orig_maz @ skims.omaz\",\n", " \"base.dest_maz @ skims.dmaz\",\n", - " )\n", + " ),\n", ")" ] }, @@ -218,9 +220,12 @@ "metadata": {}, "outputs": [], "source": [ - "flow = tree.setup_flow({\n", - " 'plain_distance': 'DISTWALK',\n", - "}, boundscheck=True)" + "flow = tree.setup_flow(\n", + " {\n", + " \"plain_distance\": \"DISTWALK\",\n", + " },\n", + " boundscheck=True,\n", + ")" ] }, { @@ -252,15 +257,20 @@ "source": [ "# TEST\n", "from pytest import approx\n", + "\n", "sparse_dat = np.array([0.01, 0.2, np.nan, 3.2, np.nan])\n", - "dense_dat = np.array([0.12,0.12,0.12,0.17,0.17])\n", - "def blend(s,d, max_s):\n", + "dense_dat = np.array([0.12, 0.12, 0.12, 0.17, 0.17])\n", + "\n", + "\n", + "def blend(s, d, max_s):\n", " out = np.zeros_like(d)\n", - " ratio = s/max_s\n", - " out = d*ratio + s*(1-ratio)\n", - " out = np.where(s>max_s, d, out)\n", + " ratio = s / max_s\n", + " out = d * ratio + s * (1 - ratio)\n", + " out = np.where(s > max_s, d, out)\n", " out = np.where(np.isnan(s), d, out)\n", " return out\n", + "\n", + "\n", "assert blend(sparse_dat, dense_dat, 1.0) == approx(flow.load().ravel())" ] }, @@ -279,11 +289,13 @@ "metadata": {}, "outputs": [], "source": [ - "flow2 = tree.setup_flow({\n", - " 'plain_distance': 'DISTWALK',\n", - " 'clip_distance': 'DISTWALK.clip(upper=0.15)',\n", - " 'square_distance': 'DISTWALK**2',\n", - "})" + "flow2 = tree.setup_flow(\n", + " {\n", + " \"plain_distance\": \"DISTWALK\",\n", + " \"clip_distance\": \"DISTWALK.clip(upper=0.15)\",\n", + " \"square_distance\": \"DISTWALK**2\",\n", + " }\n", + ")" ] }, { @@ -304,12 +316,17 @@ "outputs": [], "source": [ "# TEST\n", - "assert flow2.load_dataframe().values == approx(np.array([\n", - " [ 1.1100e-02, 1.1100e-02, 1.2321e-04],\n", - " [ 1.8400e-01, 1.5000e-01, 3.3856e-02],\n", - " [ 1.2000e-01, 1.2000e-01, 1.4400e-02],\n", - " [ 1.7000e-01, 1.5000e-01, 2.8900e-02],\n", - " [ 1.7000e-01, 1.5000e-01, 2.8900e-02]], dtype=np.float32)\n", + "assert flow2.load_dataframe().values == approx(\n", + " np.array(\n", + " [\n", + " [1.1100e-02, 1.1100e-02, 1.2321e-04],\n", + " [1.8400e-01, 1.5000e-01, 3.3856e-02],\n", + " [1.2000e-01, 1.2000e-01, 1.4400e-02],\n", + " [1.7000e-01, 1.5000e-01, 2.8900e-02],\n", + " [1.7000e-01, 1.5000e-01, 2.8900e-02],\n", + " ],\n", + " dtype=np.float32,\n", + " )\n", ")" ] }, @@ -340,7 +357,7 @@ "skims.at(\n", " omaz=trips.orig_maz,\n", " dmaz=trips.dest_maz,\n", - " _names=['DIST', 'DISTWALK'],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", ")" ] }, @@ -355,24 +372,26 @@ "out = skims.at(\n", " omaz=trips.orig_maz,\n", " dmaz=trips.dest_maz,\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " 
_names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", + " out[\"DIST\"].to_numpy(), np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", ")\n", "\n", "from pytest import raises\n", + "\n", "with raises(NotImplementedError):\n", " skims.at(\n", " omaz=trips.orig_maz,\n", " dmaz=trips.dest_maz,\n", - " time_period=['AM', 'AM', 'AM', 'AM', 'AM'],\n", - " _names=['DIST', 'DISTWALK', 'SOV_TIME'], _load=True,\n", + " time_period=[\"AM\", \"AM\", \"AM\", \"AM\", \"AM\"],\n", + " _names=[\"DIST\", \"DISTWALK\", \"SOV_TIME\"],\n", + " _load=True,\n", " )" ] }, @@ -384,9 +403,9 @@ "outputs": [], "source": [ "skims.iat(\n", - " omaz=[ 0, 0, 0, 100, 100],\n", - " dmaz=[ 0, 1, 3, 101, 102],\n", - " _names=['DIST', 'DISTWALK'],\n", + " omaz=[0, 0, 0, 100, 100],\n", + " dmaz=[0, 1, 3, 101, 102],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", ")" ] }, @@ -399,18 +418,18 @@ "source": [ "# TEST\n", "out = skims.iat(\n", - " omaz=[ 0, 0, 0, 100, 100],\n", - " dmaz=[ 0, 1, 3, 101, 102],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " omaz=[0, 0, 0, 100, 100],\n", + " dmaz=[0, 1, 3, 101, 102],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", + " out[\"DIST\"].to_numpy(), np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", ")\n", "np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", - ")\n" + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")" ] }, { @@ -430,9 +449,10 @@ "outputs": [], "source": [ "skims.at(\n", - " otaz=[1,1,1,16,16],\n", - " dtaz=[1,1,1,16,16],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " otaz=[1, 1, 1, 16, 16],\n", + " dtaz=[1, 1, 1, 16, 16],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")" ] }, @@ -444,9 +464,9 @@ "outputs": [], "source": [ "skims.at(\n", - " otaz=[1,1,1,16,16],\n", - " dtaz=[1,1,1,16,16],\n", - " _name='DISTWALK',\n", + " otaz=[1, 1, 1, 16, 16],\n", + " dtaz=[1, 1, 1, 16, 16],\n", + " _name=\"DISTWALK\",\n", ")" ] }, @@ -458,44 +478,47 @@ "outputs": [], "source": [ "# TEST\n", - "import sys\n", - "if sys.version_info > (3,8):\n", - " import secrets\n", - " token = \"skims-with-sparse\" + secrets.token_hex(5)\n", - " readback0 = skims.shm.to_shared_memory(token)\n", - " assert readback0.attrs == skims.attrs\n", - " readback = sh.Dataset.shm.from_shared_memory(token)\n", - " assert readback.attrs == skims.attrs\n", - " \n", - " out = readback.iat(\n", - " omaz=[ 0, 0, 0, 100, 100],\n", - " dmaz=[ 0, 1, 3, 101, 102],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " 
)\n", + "import secrets\n", "\n", - " out = readback.at(\n", - " omaz=trips.orig_maz,\n", - " dmaz=trips.dest_maz,\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DIST'].to_numpy(), \n", - " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " )\n", - " np.testing.assert_array_almost_equal(\n", - " out['DISTWALK'].to_numpy(), \n", - " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32)\n", - " )\n", - " \n", - " assert readback.redirection.blenders == {'DISTWALK': {'max_blend_distance': 1.0, 'blend_distance_name': None}}\n" + "token = \"skims-with-sparse\" + secrets.token_hex(5)\n", + "readback0 = skims.shm.to_shared_memory(token)\n", + "assert readback0.attrs == skims.attrs\n", + "readback = sh.Dataset.shm.from_shared_memory(token)\n", + "assert readback.attrs == skims.attrs\n", + "\n", + "out = readback.iat(\n", + " omaz=[0, 0, 0, 100, 100],\n", + " dmaz=[0, 1, 3, 101, 102],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DIST\"].to_numpy(),\n", + " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "\n", + "out = readback.at(\n", + " omaz=trips.orig_maz,\n", + " dmaz=trips.dest_maz,\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DIST\"].to_numpy(),\n", + " np.array([0.12, 0.12, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "np.testing.assert_array_almost_equal(\n", + " out[\"DISTWALK\"].to_numpy(),\n", + " np.array([0.0111, 0.184, 0.12, 0.17, 0.17], dtype=np.float32),\n", + ")\n", + "\n", + "assert readback.redirection.blenders == {\n", + " \"DISTWALK\": {\"max_blend_distance\": 1.0, \"blend_distance_name\": None}\n", + "}" ] }, { @@ -506,7 +529,9 @@ "outputs": [], "source": [ "# TEST\n", - "assert skims.redirection.blenders == {'DISTWALK': {'max_blend_distance': 1.0, 'blend_distance_name': None}}" + "assert skims.redirection.blenders == {\n", + " \"DISTWALK\": {\"max_blend_distance\": 1.0, \"blend_distance_name\": None}\n", + "}" ] }, { @@ -518,24 +543,28 @@ "source": [ "# TEST\n", "# reverse skims in sparse\n", - "flow3 = tree.setup_flow({\n", - " 'plain_distance': 'DISTWALK',\n", - " 'reverse_distance': 'skims.reverse(\"DISTWALK\")',\n", - "})\n", + "flow3 = tree.setup_flow(\n", + " {\n", + " \"plain_distance\": \"DISTWALK\",\n", + " \"reverse_distance\": 'skims.reverse(\"DISTWALK\")',\n", + " }\n", + ")\n", "\n", - "assert flow3.load() == approx(np.array([[ 0.0111, 0.0111],\n", - " [ 0.184 , 0.12 ],\n", - " [ 0.12 , 0.12 ],\n", - " [ 0.17 , 0.17 ],\n", - " [ 0.17 , 0.17 ]], dtype=np.float32))\n", + "assert flow3.load() == approx(\n", + " np.array(\n", + " [[0.0111, 0.0111], [0.184, 0.12], [0.12, 0.12], [0.17, 0.17], [0.17, 0.17]],\n", + " dtype=np.float32,\n", + " )\n", + ")\n", "\n", "z = skims.iat(\n", - " omaz=[ 0, 1, 3, 101, 102],\n", - " dmaz=[ 0, 0, 0, 100, 100],\n", - " _names=['DIST', 'DISTWALK'], _load=True,\n", + " omaz=[0, 1, 3, 101, 102],\n", + " dmaz=[0, 0, 0, 100, 100],\n", + " _names=[\"DIST\", \"DISTWALK\"],\n", + " _load=True,\n", ")\n", - "assert z['DISTWALK'].data == approx(np.array([ 0.0111, 0.12 , 0.12 , 0.17 , 0.17 ]))\n", - "assert z['DIST'].data == approx(np.array([ 0.12, 0.12 , 0.12 , 0.17 , 0.17 ]))" + "assert 
z[\"DISTWALK\"].data == approx(np.array([0.0111, 0.12, 0.12, 0.17, 0.17]))\n", + "assert z[\"DIST\"].data == approx(np.array([0.12, 0.12, 0.12, 0.17, 0.17]))" ] } ], diff --git a/docs/walkthrough/two-dim.ipynb b/docs/walkthrough/two-dim.ipynb index b7fb5b6..4b68245 100644 --- a/docs/walkthrough/two-dim.ipynb +++ b/docs/walkthrough/two-dim.ipynb @@ -19,10 +19,10 @@ "outputs": [], "source": [ "import numpy as np\n", - "import pandas as pd\n", "import xarray as xr\n", "\n", "import sharrow as sh\n", + "\n", "sh.__version__" ] }, @@ -39,6 +39,7 @@ "source": [ "# TEST check versions\n", "import packaging\n", + "\n", "assert packaging.version.parse(xr.__version__) >= packaging.version.parse(\"0.20.2\")" ] }, @@ -83,7 +84,7 @@ "source": [ "# test households content\n", "assert len(households) == 5000\n", - "assert \"income\" in households \n", + "assert \"income\" in households\n", "assert households.index.name == \"HHID\"" ] }, @@ -111,7 +112,7 @@ "source": [ "assert len(persons) == 8212\n", "assert \"household_id\" in persons\n", - "assert persons.index.name == 'PERID'" + "assert persons.index.name == \"PERID\"" ] }, { @@ -180,7 +181,7 @@ "metadata": {}, "outputs": [], "source": [ - "workers = persons.query(\"pemploy in [1,2]\").rename_axis(index='WORKERID')\n", + "workers = persons.query(\"pemploy in [1,2]\").rename_axis(index=\"WORKERID\")\n", "workers" ] }, @@ -215,8 +216,8 @@ "metadata": {}, "outputs": [], "source": [ - "skims_am = skims.sel(time_period='AM')\n", - "skims_pm = skims.sel(time_period='PM')" + "skims_am = skims.sel(time_period=\"AM\")\n", + "skims_pm = skims.sel(time_period=\"PM\")" ] }, { @@ -246,7 +247,7 @@ "outputs": [], "source": [ "base = sh.dataset.from_named_objects(\n", - " workers.index, \n", + " workers.index,\n", " landuse.index,\n", ")" ] @@ -279,7 +280,7 @@ "metadata": {}, "outputs": [], "source": [ - "tree = sh.DataTree(base=base, dim_order=('WORKERID', 'TAZ'))" + "tree = sh.DataTree(base=base, dim_order=(\"WORKERID\", \"TAZ\"))" ] }, { @@ -294,7 +295,7 @@ "outputs": [], "source": [ "# TEST tree_dest attributes\n", - "assert tree.dim_order == ('WORKERID', 'TAZ')\n", + "assert tree.dim_order == (\"WORKERID\", \"TAZ\")\n", "assert tree.shape == (4361, 25)" ] }, @@ -317,7 +318,7 @@ "metadata": {}, "outputs": [], "source": [ - "tree.add_dataset('person', persons, \"base.WORKERID @ person.PERID\")" + "tree.add_dataset(\"person\", persons, \"base.WORKERID @ person.PERID\")" ] }, { @@ -337,8 +338,8 @@ "metadata": {}, "outputs": [], "source": [ - "tree.add_dataset('landuse', landuse, \"base.TAZ @ landuse.TAZ\")\n", - "tree.add_dataset('hh', households, \"person.household_id @ hh.HHID\")" + "tree.add_dataset(\"landuse\", landuse, \"base.TAZ @ landuse.TAZ\")\n", + "tree.add_dataset(\"hh\", households, \"person.household_id @ hh.HHID\")" ] }, { @@ -360,17 +361,17 @@ "outputs": [], "source": [ "tree.add_dataset(\n", - " 'odskims', \n", - " skims_am, \n", + " \"odskims\",\n", + " skims_am,\n", " relationships=(\n", - " \"hh.TAZ @ odskims.otaz\", \n", + " \"hh.TAZ @ odskims.otaz\",\n", " \"base.TAZ @ odskims.dtaz\",\n", " ),\n", ")\n", "\n", "tree.add_dataset(\n", - " 'doskims', \n", - " skims_pm, \n", + " \"doskims\",\n", + " skims_pm,\n", " relationships=(\n", " \"base.TAZ @ doskims.otaz\",\n", " \"hh.TAZ @ doskims.dtaz\",\n", @@ -399,10 +400,10 @@ "outputs": [], "source": [ "definition = {\n", - " 'round_trip_dist': 'odskims.DIST + doskims.DIST',\n", - " 'round_trip_dist_first_mile': 'clip(odskims.DIST, 0, 1) + clip(doskims.DIST, 0, 1)',\n", - " 
'round_trip_dist_addl_miles': 'clip(odskims.DIST-1, 0, None) + clip(doskims.DIST-1, 0, None)',\n", - " 'size_term': 'log(TOTPOP + 0.5*EMPRES)',\n", + " \"round_trip_dist\": \"odskims.DIST + doskims.DIST\",\n", + " \"round_trip_dist_first_mile\": \"clip(odskims.DIST, 0, 1) + clip(doskims.DIST, 0, 1)\",\n", + " \"round_trip_dist_addl_miles\": \"clip(odskims.DIST-1, 0, None) + clip(doskims.DIST-1, 0, None)\",\n", + " \"size_term\": \"log(TOTPOP + 0.5*EMPRES)\",\n", "}\n", "\n", "flow = tree.setup_flow(definition)" @@ -440,37 +441,46 @@ "source": [ "# TEST\n", "assert arr.shape == (4361, 25, 4)\n", - "expected = np.array([\n", - " [[ 0.61 , 0.61 , 0. , 4.610157],\n", - " [ 0.28 , 0.28 , 0. , 5.681878],\n", - " [ 0.56 , 0.56 , 0. , 6.368187],\n", - " [ 0.53 , 0.53 , 0. , 5.741399],\n", - " [ 1.23 , 1.23 , 0. , 7.17549 ]],\n", - "\n", - " [[ 1.19 , 1.19 , 0. , 4.610157],\n", - " [ 1.49 , 1.49 , 0. , 5.681878],\n", - " [ 1.88 , 1.85 , 0.03 , 6.368187],\n", - " [ 1.36 , 1.36 , 0. , 5.741399],\n", - " [ 1.93 , 1.93 , 0. , 7.17549 ]],\n", - "\n", - " [[ 1.19 , 1.19 , 0. , 4.610157],\n", - " [ 1.49 , 1.49 , 0. , 5.681878],\n", - " [ 1.88 , 1.85 , 0.03 , 6.368187],\n", - " [ 1.36 , 1.36 , 0. , 5.741399],\n", - " [ 1.93 , 1.93 , 0. , 7.17549 ]],\n", - "\n", - " [[ 0.24 , 0.24 , 0. , 4.610157],\n", - " [ 0.61 , 0.61 , 0. , 5.681878],\n", - " [ 1.01 , 1.01 , 0. , 6.368187],\n", - " [ 0.75 , 0.75 , 0. , 5.741399],\n", - " [ 1.38 , 1.38 , 0. , 7.17549 ]],\n", - "\n", - " [[ 0.61 , 0.61 , 0. , 4.610157],\n", - " [ 0.28 , 0.28 , 0. , 5.681878],\n", - " [ 0.56 , 0.56 , 0. , 6.368187],\n", - " [ 0.53 , 0.53 , 0. , 5.741399],\n", - " [ 1.23 , 1.23 , 0. , 7.17549 ]],\n", - "], dtype=np.float32)\n", + "expected = np.array(\n", + " [\n", + " [\n", + " [0.61, 0.61, 0.0, 4.610157],\n", + " [0.28, 0.28, 0.0, 5.681878],\n", + " [0.56, 0.56, 0.0, 6.368187],\n", + " [0.53, 0.53, 0.0, 5.741399],\n", + " [1.23, 1.23, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [1.19, 1.19, 0.0, 4.610157],\n", + " [1.49, 1.49, 0.0, 5.681878],\n", + " [1.88, 1.85, 0.03, 6.368187],\n", + " [1.36, 1.36, 0.0, 5.741399],\n", + " [1.93, 1.93, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [1.19, 1.19, 0.0, 4.610157],\n", + " [1.49, 1.49, 0.0, 5.681878],\n", + " [1.88, 1.85, 0.03, 6.368187],\n", + " [1.36, 1.36, 0.0, 5.741399],\n", + " [1.93, 1.93, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [0.24, 0.24, 0.0, 4.610157],\n", + " [0.61, 0.61, 0.0, 5.681878],\n", + " [1.01, 1.01, 0.0, 6.368187],\n", + " [0.75, 0.75, 0.0, 5.741399],\n", + " [1.38, 1.38, 0.0, 7.17549],\n", + " ],\n", + " [\n", + " [0.61, 0.61, 0.0, 4.610157],\n", + " [0.28, 0.28, 0.0, 5.681878],\n", + " [0.56, 0.56, 0.0, 6.368187],\n", + " [0.53, 0.53, 0.0, 5.741399],\n", + " [1.23, 1.23, 0.0, 7.17549],\n", + " ],\n", + " ],\n", + " dtype=np.float32,\n", + ")\n", "\n", "np.testing.assert_array_almost_equal(arr[:5, :5, :], expected)" ] @@ -529,10 +539,20 @@ "source": [ "# TEST\n", "assert isinstance(arr_pretty, xr.DataArray)\n", - "assert arr_pretty.dims == ('WORKERID', 'TAZ', 'expressions')\n", + "assert arr_pretty.dims == (\"WORKERID\", \"TAZ\", \"expressions\")\n", "assert arr_pretty.shape == (4361, 25, 4)\n", - "assert all(arr_pretty.expressions == np.array(['round_trip_dist', 'round_trip_dist_first_mile',\n", - " 'round_trip_dist_addl_miles', 'size_term'], dtype='=42,<64", - "wheel", - "setuptools_scm[toml]>=7.0", + "setuptools>=69", + "setuptools_scm>=8", ] build-backend = "setuptools.build_meta" +[project] +name = "sharrow" +requires-python = ">=3.9" +dynamic = ["version"] 
+dependencies = [
+    "numpy >= 1.19",
+    "pandas >= 1.2",
+    "pyarrow",
+    "xarray",
+    "numba >= 0.51.2",
+    "numexpr",
+    "filelock",
+    "dask",
+    "networkx",
+]
+classifiers = [
+    "License :: OSI Approved :: BSD License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+]
+description = "numba for ActivitySim-style spec files"
+readme = "README.md"
+keywords = ["activitysim", "discrete choice"]
+
+[project.urls]
+Documentation = "https://activitysim.github.io/sharrow/"
+Repository = "https://github.com/activitysim/sharrow"
+
+[tool.setuptools]
+packages = ["sharrow", "sharrow.utils"]
+
 [tool.setuptools_scm]
 fallback_version = "1999"
 write_to = "sharrow/_version.py"
 
-[tool.isort]
-profile = "black"
-skip_gitignore = true
-float_to_top = true
-default_section = "THIRDPARTY"
-known_first_party = "sharrow"
-
 [tool.ruff]
-# Enable flake8-bugbear (`B`) and pyupgrade ('UP') rules.
-select = ["E", "F", "B", "UP"]
+select = [
+    "F",  # Pyflakes
+    "E",  # Pycodestyle Errors
+    "W",  # Pycodestyle Warnings
+    "I",  # isort
+    "UP",  # pyupgrade
+    "D",  # pydocstyle
+    "B",  # flake8-bugbear
+]
 fix = true
 ignore-init-module-imports = true
-line-length = 120
-ignore = ["B905"]
+line-length = 88
+ignore = ["B905", "D1"]
 target-version = "py39"
+extend-include = ["*.ipynb"]
+per-file-ignores = { "*.ipynb" = [
+    "E402",  # allow imports to appear anywhere in Jupyter Notebooks
+    "E501",  # allow long lines in Jupyter Notebooks
+] }
+
+[tool.ruff.lint.isort]
+known-first-party = ["sharrow"]
+
+[tool.ruff.lint.pycodestyle]
+max-line-length = 120
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
 
 [tool.pytest.ini_options]
 minversion = "6.0"
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 28d8c14..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,36 +0,0 @@
-[metadata]
-name = sharrow
-author = Cambridge Systematics
-author_email = jeffnewman@camsys.com
-license = BSD-3-Clause
-url = https://github.com/ActivitySim/sharrow
-description = numba for ActivitySim-style spec files
-long_description = file: README.md
-long_description_content_type = text/markdown
-
-[options]
-packages = find:
-zip_safe = False
-include_package_data = True
-python_requires = >=3.7
-install_requires =
-    numpy >= 1.19
-    pandas >= 1.2
-    pyarrow >= 3.0.0
-    xarray >= 0.20.0
-    numba >= 0.54
-    sparse
-    numexpr
-    filelock
-    dask
-    networkx
-    astunparse;python_version<'3.9'
-
-[flake8]
-exclude =
-    .git,
-    __pycache__,
-    docs/_build,
-    sharrow/__init__.py
-max-line-length = 160
-extend-ignore = E203, E731
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 02aeac1..0000000
--- a/setup.py
+++ /dev/null
@@ -1,4 +0,0 @@
-#!/usr/bin/env python
-from setuptools import setup
-
-setup(use_scm_version={"fallback_version": "1999"})
diff --git a/sharrow/accessors.py b/sharrow/accessors.py
index f83de83..9962db2 100644
--- a/sharrow/accessors.py
+++ b/sharrow/accessors.py
@@ -1,3 +1,5 @@
+"""Convenience accessor wrappers for xarray objects."""
+
 import xarray as xr
 
 
diff --git a/sharrow/categorical.py b/sharrow/categorical.py
index abddae2..a39bf89 100644
--- a/sharrow/categorical.py
+++ b/sharrow/categorical.py
@@ -14,9 +14,7 @@ class ArrayIsNotCategoricalError(TypeError):
 
 @xr.register_dataarray_accessor("cat")
 class _Categorical:
-    """
-    Accessor for pseudo-categorical arrays.
-    """
+    """Accessor for pseudo-categorical arrays."""
 
     __slots__ = ("dataarray",)
diff --git a/sharrow/dataset.py b/sharrow/dataset.py
index 6ffab94..85b7de2 100755
--- a/sharrow/dataset.py
+++ b/sharrow/dataset.py
@@ -79,7 +79,7 @@ def clean(s):
 
 def construct(source):
     """
-    A generic constructor for creating Datasets from various similar objects.
+    Create Datasets from various similar objects.
 
     Parameters
     ----------
@@ -111,7 +111,7 @@ def dataset_from_dataframe_fast(
     sparse: bool = False,
     preserve_cat: bool = True,
 ) -> Dataset:
-    """Convert a pandas.DataFrame into an xarray.Dataset
+    """Convert a pandas.DataFrame into an xarray.Dataset.
 
     Each column will be converted into an independent variable in the
     Dataset. If the dataframe's index is a MultiIndex, it will be expanded
@@ -146,7 +146,6 @@ def dataset_from_dataframe_fast(
     xarray.DataArray.from_series
     pandas.DataFrame.to_xarray
     """
-
     # this is much faster than the default xarray version when not
     # using a MultiIndex.
 
@@ -215,7 +214,7 @@ def from_table(
     index=None,
 ):
     """
-    Convert a pyarrow.Table into an xarray.Dataset
+    Convert a pyarrow.Table into an xarray.Dataset.
 
     Parameters
     ----------
@@ -320,7 +319,6 @@ def from_omx(
     -------
     Dataset
     """
-
     # handle both larch.OMX and openmatrix.open_file versions
     if "lar" in type(omx).__module__:
         omx_data = omx.data
@@ -695,9 +693,7 @@ def is_dict_like(value: Any) -> bool:
 
 @xr.register_dataset_accessor("single_dim")
 class _SingleDim:
-    """
-    Convenience accessor for single-dimension datasets.
-    """
+    """Convenience accessor for single-dimension datasets."""
 
     __slots__ = ("dataset", "dim_name")
@@ -839,9 +835,7 @@ def eval(
 
 @xr.register_dataarray_accessor("single_dim")
 class _SingleDimArray:
-    """
-    Convenience accessor for single-dimension datasets.
-    """
+    """Convenience accessor for single-dimension datasets."""
 
     __slots__ = ("dataarray", "dim_name")
@@ -1194,7 +1188,7 @@ def to_table(self):
     from .relationships import sparse_array_type
 
     def to_numpy(var):
-        """Coerces wrapped data to numpy and returns a numpy.ndarray"""
+        """Coerces wrapped data to numpy and returns a numpy.ndarray."""
         data = var.data
         if hasattr(data, "chunks"):
             data = data.compute()
@@ -1218,7 +1212,7 @@ def to_numpy(var):
 @register_dataset_method
 def select_and_rename(self, name_dict=None, **names):
     """
-    Select and rename variables from this Dataset
+    Select and rename variables from this Dataset.
 
     Parameters
     ----------
diff --git a/sharrow/datastore.py b/sharrow/datastore.py
index 4be79df..6a98d2c 100644
--- a/sharrow/datastore.py
+++ b/sharrow/datastore.py
@@ -19,7 +19,7 @@ def timestamp():
 
 
 class ReadOnlyError(ValueError):
-    """This object is read-only."""
+    """Object is read-only."""
 
 
 def _read_parquet(filename, index_col=None) -> xr.Dataset:
@@ -377,7 +377,7 @@ def _write_metadata(self):
 
     def read_metadata(self, checkpoints=None):
         """
-        Read storage metadata
+        Read storage metadata.
 
        Parameters
        ----------
@@ -478,5 +478,5 @@ def digitize_relationships(self, redigitize=True):
 
     @property
     def relationships_are_digitized(self) -> bool:
-        """bool : Whether all relationships are digital (by position)."""
+        """Bool : Whether all relationships are digital (by position)."""
         return self._tree.relationships_are_digitized
diff --git a/sharrow/digital_encoding.py b/sharrow/digital_encoding.py
index e260022..95a1cc1 100644
--- a/sharrow/digital_encoding.py
+++ b/sharrow/digital_encoding.py
@@ -167,7 +167,7 @@ def digitize_by_dictionary(arr, bitwidth=8):
         bin_edges = (bins[1:] - bins[:-1]) / 2 + bins[:-1]
     except TypeError:
         # bins are not numeric
-        bin_map = {x:n for n,x in enumerate(bins)}
+        bin_map = {x: n for n, x in enumerate(bins)}
         u, inv = np.unique(arr.data, return_inverse=True)
         result.data = np.array([bin_map.get(x) for x in u])[inv].reshape(arr.shape)
         result.attrs["digital_encoding"] = {
@@ -334,7 +334,7 @@ def multivalue_digitize_by_dictionary(ds, encode_vars=None, encoding_name=None):
 
     Returns
     -------
-
+    Dataset
     """
     logger = logging.getLogger("sharrow")
     if not isinstance(encoding_name, str):
diff --git a/sharrow/flows.py b/sharrow/flows.py
index 6345b6d..849cff8 100644
--- a/sharrow/flows.py
+++ b/sharrow/flows.py
@@ -1022,7 +1022,7 @@ def __initialize_1(
         bool_wrapping=False,
     ):
         """
-        Initialize up to the flow_hash
+        Initialize up to the flow_hash.
 
         See main docstring for arguments.
         """
@@ -1050,7 +1050,9 @@ def __initialize_1(
         all_raw_names |= attribute_pairs.get(self.tree.root_node_name, set())
         all_raw_names |= subscript_pairs.get(self.tree.root_node_name, set())
 
-        dimensions_ordered = presorted(self.tree.sizes, self.dim_order, self.dim_exclude)
+        dimensions_ordered = presorted(
+            self.tree.sizes, self.dim_order, self.dim_exclude
+        )
         index_slots = {i: n for n, i in enumerate(dimensions_ordered)}
         self.arg_name_positions = index_slots
         self.arg_names = dimensions_ordered
@@ -1514,6 +1516,7 @@ def __initialize_2(
         with_root_node_name=None,
     ):
         """
+        Second step in initialization, only used if the flow is not cached.
 
         Parameters
         ----------
@@ -1535,7 +1538,6 @@ def __initialize_2(
         be sure to avoid name conflicts with other flow's in the same directory.
""" - if self._hashing_level <= 1: func_code, all_name_tokens = self.init_sub_funcs( defs, @@ -1726,7 +1728,9 @@ def __initialize_2( root_dims = list( presorted( - self.tree._graph.nodes[with_root_node_name]["dataset"].sizes, + self.tree._graph.nodes[with_root_node_name][ + "dataset" + ].sizes, self.dim_order, self.dim_exclude, ) @@ -1938,7 +1942,7 @@ def load_raw(self, rg, args, runner=None, dtype=None, dot=None): # raise the inner key error which is more helpful context = getattr(err, "__context__", None) if context: - raise context + raise context from None else: raise err @@ -2768,8 +2772,7 @@ def function_names(self, x): self._raw_functions[name] = (None, None, set(), []) def _spill(self, all_name_tokens=None): - cmds = [self.tree._spill(all_name_tokens)] - cmds.append("\n") + cmds = ["\n"] cmds.append(f"output_name_positions = {self.output_name_positions!r}") cmds.append(f"function_names = {self.function_names!r}") return "\n".join(cmds) diff --git a/sharrow/relationships.py b/sharrow/relationships.py index ab025d6..e86eb66 100644 --- a/sharrow/relationships.py +++ b/sharrow/relationships.py @@ -154,7 +154,7 @@ def xgather(source, positions, indexes): def _dataarray_to_numpy(self) -> np.ndarray: - """Coerces wrapped data to numpy and returns a numpy.ndarray""" + """Coerces wrapped data to numpy and returns a numpy.ndarray.""" data = self.data if isinstance(data, dask_array_type): data = data.compute() @@ -165,9 +165,7 @@ def _dataarray_to_numpy(self) -> np.ndarray: class Relationship: - """ - Defines a linkage between datasets in a `DataTree`. - """ + """Defines a linkage between datasets in a `DataTree`.""" def __init__( self, @@ -543,7 +541,7 @@ def get_relationship(self, parent, child): return Relationship(parent_data=parent, child_data=child, **attrs) def list_relationships(self) -> list[Relationship]: - """list : List all relationships defined in this tree.""" + """List : List all relationships defined in this tree.""" result = [] for e in self._graph.edges: result.append(self._get_relationship(e)) @@ -904,7 +902,7 @@ def subspaces_iter(self): def contains_subspace(self, key) -> bool: """ - Is this named Dataset in this tree's subspaces + Is this named Dataset in this tree's subspaces. Parameters ---------- @@ -918,7 +916,7 @@ def contains_subspace(self, key) -> bool: def get_subspace(self, key, default_empty=False) -> xr.Dataset: """ - Access named Dataset from this tree's subspaces + Access named Dataset from this tree's subspaces. Parameters ---------- @@ -954,9 +952,7 @@ def namespace_names(self): @property def dims(self): - """ - Mapping from dimension names to lengths across all dataset nodes. - """ + """Mapping from dimension names to lengths across all dataset nodes.""" dims = {} for _k, v in self.subspaces_iter(): for name, length in v.dims.items(): @@ -1005,7 +1001,6 @@ def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): Returns self if dropping inplace, otherwise returns a copy with dimensions dropped. """ - if isinstance(dims, str): dims = [dims] if inplace: @@ -1039,7 +1034,7 @@ def drop_dims(self, dims, inplace=False, ignore_missing_dims=True): while boot_queue: b = boot_queue.pop() booted.add(b) - for (up, dn, _n) in obj._graph.edges.keys(): + for up, dn, _n in obj._graph.edges.keys(): if up == b: boot_queue.add(dn) @@ -1257,21 +1252,6 @@ def setup_flow( with_root_node_name=with_root_node_name, ) - def _spill(self, all_name_tokens=()): - """ - Write backup code for sharrow-lite. 
- - Parameters - ---------- - all_name_tokens - - Returns - ------- - - """ - cmds = [] - return "\n".join(cmds) - def get_named_array(self, mangled_name): if mangled_name[:2] != "__": raise KeyError(mangled_name) @@ -1311,7 +1291,6 @@ def digitize_relationships(self, inplace=False, redigitize=True): DataTree or None Only returns a copy if not digitizing in-place. """ - if inplace: obj = self else: @@ -1385,7 +1364,7 @@ def mapper_get(x, mapper=mapper): @property def relationships_are_digitized(self): - """bool : Whether all relationships are digital (by position).""" + """Bool : Whether all relationships are digital (by position).""" for e in self._graph.edges: r = self._get_relationship(e) if r.indexing != "position": diff --git a/sharrow/selectors.py b/sharrow/selectors.py index 927cfaa..b057f89 100644 --- a/sharrow/selectors.py +++ b/sharrow/selectors.py @@ -134,7 +134,6 @@ def _filter( ds_ = ds if _names: - result = ( getattr(ds_, _func)(**loaders) .digital_encoding.strip(_names) diff --git a/sharrow/shared_memory.py b/sharrow/shared_memory.py index 3bd273b..d2b9e23 100644 --- a/sharrow/shared_memory.py +++ b/sharrow/shared_memory.py @@ -237,9 +237,7 @@ def __repr__(self): return r def release_shared_memory(self): - """ - Release shared memory allocated to this Dataset. - """ + """Release shared memory allocated to this Dataset.""" release_shared_memory(self._shared_memory_key_) @staticmethod @@ -475,9 +473,7 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): _size_p // _dtype_p.itemsize, dtype=_dtype_p, buffer=buffer[ - position - + _size_d - + _size_i : position + position + _size_d + _size_i : position + _size_d + _size_i + _size_p @@ -507,7 +503,7 @@ def from_shared_memory(cls, key, own_data=False, mode="r+"): @property def shared_memory_size(self): - """int : Size (in bytes) in shared memory, raises ValueError if not shared.""" + """Int : Size (in bytes) in shared memory, raises ValueError if not shared.""" try: return sum(i.size for i in self._shared_memory_objs_) except AttributeError: @@ -515,7 +511,7 @@ def shared_memory_size(self): @property def is_shared_memory(self): - """bool : Whether this Dataset is in shared memory.""" + """Bool : Whether this Dataset is in shared memory.""" try: return sum(i.size for i in self._shared_memory_objs_) > 0 except AttributeError: diff --git a/sharrow/sparse.py b/sharrow/sparse.py index 77448ea..d96035f 100644 --- a/sharrow/sparse.py +++ b/sharrow/sparse.py @@ -95,12 +95,13 @@ def __init__(self, xarray_obj): def set(self, m2t, map_to, map_also=None, name=None): """ + Set the redirection of a dimension. + Parameters ---------- m2t : pandas.Series Mapping maz's to tazs """ - if name is None: name = f"redirect_{map_to}" diff --git a/sharrow/tests/conftest.py b/sharrow/tests/conftest.py index e532de1..f3ba25f 100644 --- a/sharrow/tests/conftest.py +++ b/sharrow/tests/conftest.py @@ -7,9 +7,7 @@ @pytest.fixture def person_dataset() -> xr.Dataset: - """ - Sample persons dataset with dummy data. - """ + """Sample persons dataset with dummy data.""" df = pd.DataFrame( { "Income": [45, 88, 56, 15, 71], @@ -26,9 +24,7 @@ def person_dataset() -> xr.Dataset: @pytest.fixture def household_dataset() -> xr.Dataset: - """ - Sample household dataset with dummy data. - """ + """Sample household dataset with dummy data.""" df = pd.DataFrame( { "n_cars": [1, 2, 1], @@ -40,9 +36,7 @@ def household_dataset() -> xr.Dataset: @pytest.fixture def tours_dataset() -> xr.Dataset: - """ - Sample tours dataset with dummy data. 
- """ + """Sample tours dataset with dummy data.""" df = pd.DataFrame( { "TourMode": ["Car", "Bus", "Car", "Car", "Walk"], diff --git a/sharrow/utils/tar_zst.py b/sharrow/utils/tar_zst.py index a5e78ad..90ac1f8 100644 --- a/sharrow/utils/tar_zst.py +++ b/sharrow/utils/tar_zst.py @@ -14,7 +14,8 @@ def extract_zst(archive: Path, out_path: Path): """ - extract .zst file + Extract content of zst file to a target file system directory. + works on Windows, Linux, MacOS, etc. Parameters @@ -24,7 +25,6 @@ def extract_zst(archive: Path, out_path: Path): out_path: pathlib.Path or str directory to extract files and directories to """ - if zstandard is None: raise ImportError("pip install zstandard") diff --git a/sharrow/wrappers.py b/sharrow/wrappers.py index d4fa3b4..baeb1e5 100644 --- a/sharrow/wrappers.py +++ b/sharrow/wrappers.py @@ -79,6 +79,7 @@ def igather(source, positions): class DatasetWrapper: def __init__(self, dataset, orig_key, dest_key, time_key=None): """ + Emulate ActivitySim's SkimWrapper. Parameters ---------- @@ -97,7 +98,7 @@ def __init__(self, dataset, orig_key, dest_key, time_key=None): def set_df(self, df): """ - Set the dataframe + Set the dataframe. Parameters ---------- @@ -137,7 +138,7 @@ def set_df(self, df): def lookup(self, key, reverse=False): """ - Generally not called by the user - use __getitem__ instead + Generally not called by the user - use __getitem__ instead. Parameters ---------- @@ -154,7 +155,6 @@ def lookup(self, key, reverse=False): A Series of impedances which are elements of the Skim object and with the same index as df """ - assert self.df is not None, "Call set_df first" if reverse: x = self.positions.rename(columns={"otaz": "dtaz", "dtaz": "otaz"}) @@ -166,7 +166,9 @@ def lookup(self, key, reverse=False): def __getitem__(self, key): """ - Get the lookup for an available skim object (df and orig/dest and column names implicit) + Get the lookup for an available skim object. + + The `df` and orig/dest and column names are implicit. Parameters ---------- @@ -176,6 +178,7 @@ def __getitem__(self, key): Returns ------- impedances: pd.Series with the same index as df - A Series of impedances values from the single Skim with specified key, indexed byt orig/dest pair + A Series of impedances values from the single Skim with specified key, + indexed byt orig/dest pair """ return self.lookup(key)