diff --git a/.github/scripts/upstream_install.py b/.github/scripts/upstream_install.py
index 334795790..87f989d7d 100644
--- a/.github/scripts/upstream_install.py
+++ b/.github/scripts/upstream_install.py
@@ -14,7 +14,7 @@ def install_deps() -> None:
         "--upgrade",
     )
     upstream_deps = (
-        "git+https://github.com/dask/dask.git#egg=dask[array]",
+        "git+https://github.com/dask/dask.git#egg=dask[array,dataframe]",
         "git+https://github.com/dask/distributed.git#egg=distributed",
         "git+https://github.com/dask/dask-ml.git#egg=dask-ml",
         "git+https://github.com/pandas-dev/pandas#egg=pandas",
diff --git a/.github/workflows/upstream.yml b/.github/workflows/upstream.yml
index c4a663137..b3e2a5b4d 100644
--- a/.github/workflows/upstream.yml
+++ b/.github/workflows/upstream.yml
@@ -1,6 +1,7 @@
 name: Upstream
 
 on:
+  pull_request:
   push:
   schedule:
     - cron: "0 1 * * *"
@@ -14,7 +15,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11"]
+        python-version: ["3.11"]
 
     steps:
       - uses: actions/checkout@v2
diff --git a/requirements-numpy2.txt b/requirements-numpy2.txt
index 491e63fb8..16d16f990 100644
--- a/requirements-numpy2.txt
+++ b/requirements-numpy2.txt
@@ -1,7 +1,7 @@
 numpy < 2.1
 xarray
-dask[array] >= 2023.01.0, <= 2024.8.0
-distributed >= 2023.01.0, <= 2024.8.0
+dask[array] >= 2023.01.0, != 2024.8.1, != 2024.9.*
+distributed >= 2023.01.0, != 2024.8.1, != 2024.9.*
 dask-ml
 scipy
 typing-extensions
diff --git a/requirements.txt b/requirements.txt
index dcc24d89b..eb117e179 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 numpy < 2
 xarray
-dask[array] >= 2023.01.0, <= 2024.8.0
-distributed >= 2023.01.0, <= 2024.8.0
+dask[array,dataframe] >= 2023.01.0, != 2024.8.1, != 2024.9.*
+distributed >= 2023.01.0, != 2024.8.1, != 2024.9.*
 dask-ml
 scipy
 typing-extensions
diff --git a/setup.cfg b/setup.cfg
index 8bacbf5f7..fa506621d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -30,8 +30,8 @@ python_requires = >=3.9
 install_requires =
     numpy < 2
     xarray
-    dask[array] >= 2022.01.0, <= 2024.8.0
-    distributed >= 2022.01.0, <= 2024.8.0
+    dask[array,dataframe] >= 2022.01.0, != 2024.8.1, != 2024.9.*
+    distributed >= 2022.01.0, != 2024.8.1, != 2024.9.*
     dask-ml
     scipy
     zarr >= 2.10.0, != 2.11.0, != 2.11.1, != 2.11.2, < 3
diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py
index 9360e318c..9f3862985 100644
--- a/sgkit/stats/aggregation.py
+++ b/sgkit/stats/aggregation.py
@@ -680,7 +680,7 @@ def variant_stats(
     --------
     :func:`count_variant_genotypes`
     """
-    from .aggregation_numba_fns import count_hom
+    from .aggregation_numba_fns import count_hom_new_axis
 
     variables.validate(ds, {call_genotype: variables.call_genotype_spec})
     mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
@@ -697,7 +697,7 @@ def variant_stats(
     G = da.asarray(ds[call_genotype].data)
     H = xr.DataArray(
         da.map_blocks(
-            lambda *args: count_hom(*args)[:, np.newaxis, :],
+            count_hom_new_axis,
             G,
             np.zeros(3, np.uint64),
             drop_axis=2,
@@ -796,7 +796,7 @@ def sample_stats(
     ValueError
         If the dataset contains mixed-ploidy genotype calls.
     """
-    from .aggregation_numba_fns import count_hom
+    from .aggregation_numba_fns import count_hom_new_axis
 
     variables.validate(ds, {call_genotype: variables.call_genotype_spec})
     mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
@@ -805,7 +805,7 @@ def sample_stats(
     GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data)
     H = xr.DataArray(
         da.map_blocks(
-            lambda *args: count_hom(*args)[:, np.newaxis, :],
+            count_hom_new_axis,
             GT,
             np.zeros(3, np.uint64),
             drop_axis=2,
diff --git a/sgkit/stats/aggregation_numba_fns.py b/sgkit/stats/aggregation_numba_fns.py
index 3335f5457..b84b92a09 100644
--- a/sgkit/stats/aggregation_numba_fns.py
+++ b/sgkit/stats/aggregation_numba_fns.py
@@ -2,6 +2,8 @@
 # in a separate file here, and imported dynamically to avoid
 # initial compilation overhead.
 
+import numpy as np
+
 from sgkit.accelerate import numba_guvectorize, numba_jit
 from sgkit.typing import ArrayLike
 
@@ -102,3 +104,7 @@ def count_hom(
         index = _classify_hom(genotypes[i])
         if index >= 0:
             out[index] += 1
+
+
+def count_hom_new_axis(genotypes: ArrayLike, _: ArrayLike) -> ArrayLike:
+    return count_hom(genotypes, _)[:, np.newaxis, :]
diff --git a/sgkit/stats/popgen.py b/sgkit/stats/popgen.py
index d000bdbee..e201dfc98 100644
--- a/sgkit/stats/popgen.py
+++ b/sgkit/stats/popgen.py
@@ -595,9 +595,7 @@ def pbs(
     cohorts = cohorts or list(itertools.combinations(range(n_cohorts), 3))  # type: ignore
     ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts_0", None))
 
-    p = da.map_blocks(
-        lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
-    )
+    p = da.map_blocks(_pbs_cohorts, t, ct, chunks=shape, new_axis=3, dtype=np.float64)
     assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)
 
     new_ds = create_dataset(