Skip to content

Commit

Permalink
[pre-commit.ci] pre-commit autoupdate (#117)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/pre-commit/pre-commit-hooks: v4.6.0 → v5.0.0](pre-commit/pre-commit-hooks@v4.6.0...v5.0.0)
- [github.com/psf/black: 24.8.0 → 24.10.0](psf/black@24.8.0...24.10.0)
- [github.com/astral-sh/ruff-pre-commit: v0.6.8 → v0.6.9](astral-sh/ruff-pre-commit@v0.6.8...v0.6.9)
- [github.com/pre-commit/mirrors-clang-format: v19.1.0 → v19.1.1](pre-commit/mirrors-clang-format@v19.1.0...v19.1.1)

* Update ops.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update ops.py

* Update tests.yml

* Update gpu-build.yml

* Update wheels.yml

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dan Foreman-Mackey <foreman.mackey@gmail.com>
  • Loading branch information
pre-commit-ci[bot] and dfm authored Oct 10, 2024
1 parent d1390ce commit 3751be8
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gpu-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:

- uses: actions/setup-python@v5
with:
python-version: 3.11
python-version: 3.12

- name: Install Python dependencies
run: |
Expand Down
7 changes: 2 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
jax-version: ["jax[cpu]"]
include:
- os: ubuntu-latest
jax-version: "'jax[cpu]==0.4.20' 'numpy<2.0'"
jax-version: ["jax"]

steps:
- uses: actions/checkout@v4
Expand All @@ -28,7 +25,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.9"
python-version: "3.12"

- name: Install fftw on ubuntu
if: ${{ matrix.os == 'ubuntu-latest' }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- uses: actions/setup-python@v5
name: Install Python
with:
python-version: "3.9"
python-version: "3.12"
- name: Build sdist
run: |
python -m pip install -U pip
Expand Down
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: "v4.6.0"
rev: "v5.0.0"
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
exclude_types: [json, binary]
- repo: https://github.com/psf/black
rev: "24.8.0"
rev: "24.10.0"
hooks:
- id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.6.8"
rev: "v0.6.9"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: "v19.1.0"
rev: "v19.1.1"
hooks:
- id: clang-format
8 changes: 7 additions & 1 deletion src/jax_finufft/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from functools import partial, reduce

import numpy as np
import jax
from jax import core
from jax import jit
from jax import numpy as jnp
Expand Down Expand Up @@ -123,7 +124,12 @@ def jvp(prim, args, tangents, *, output_shape, iflag, eps, opts):
)
output_tangents += [s * output_tangent[:, :, n] for n, s in enumerate(scales)]

return output, reduce(ad.add_tangents, output_tangents, ad.Zero.from_value(output))
if jax.version.__version_info__ < (0, 4, 34):
zero = ad.Zero.from_value(output)
else:
zero = ad.Zero.from_primal_value(output)

return output, reduce(ad.add_tangents, output_tangents, zero)


def transpose(doutput, source, *points, output_shape, eps, iflag, opts):
Expand Down

0 comments on commit 3751be8

Please sign in to comment.