From 551f72979c85b27ca77ce5851a9fa6e5757bf157 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 8 Aug 2024 05:59:19 -0700 Subject: [PATCH] Rollback of #22869 This is causing breakages due to overly-restrictive checks on kwargs Reverts 893ae6eb800851b1c17c437982608bb59d3bc6be PiperOrigin-RevId: 660803968 --- jax/_src/api_util.py | 11 ----------- jax/_src/custom_batching.py | 9 ++------- jax/_src/custom_derivatives.py | 21 ++++++++++++++++----- tests/api_test.py | 26 -------------------------- 4 files changed, 18 insertions(+), 49 deletions(-) diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index 481dec0065a5..dd1cdcbe6bb8 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -556,17 +556,6 @@ def _assert_no_intersection(static_argnames, donate_argnames): f"{out} appear in both static_argnames and donate_argnames") -def resolve_kwargs(fun: Callable, args, kwargs): - if isinstance(fun, partial): - fun = lambda *args, **kwargs: None - ba = inspect.signature(fun).bind(*args, **kwargs) - ba.apply_defaults() - if ba.kwargs: - raise TypeError("keyword arguments could not be resolved to positions") - else: - return ba.args - - def _dtype(x): try: return dtypes.result_type(x) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 4b859e910165..4d41849b75d3 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -27,7 +27,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs +from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters.batching import not_mapped @@ -64,12 +64,7 @@ def def_vmap(self, vmap_rule: Callable) -> Callable: @traceback_util.api_boundary def __call__(self, *args, **kwargs): - fun_name = getattr(self.fun, "__name__", str(self.fun)) - if not self.vmap_rule: - raise AttributeError( - f"No batching rule defined for custom_vmap function {fun_name} " - "using def_vmap.") - args = resolve_kwargs(self.fun, args, kwargs) + assert not kwargs args_flat, in_tree = tree_flatten(args) flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index bc9f7a687dcb..d27b0efc7e5e 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,6 +17,7 @@ from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial, wraps +import inspect from typing import Any, Generic, TypeVar from jax._src import config @@ -29,8 +30,7 @@ from jax._src import traceback_util from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) -from jax._src.api_util import ( - argnums_partial, flatten_fun_nokwargs, resolve_kwargs) +from jax._src.api_util import argnums_partial, flatten_fun_nokwargs from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad @@ -56,6 +56,17 @@ ### util +def _resolve_kwargs(fun, args, kwargs): + if isinstance(fun, partial): + # functools.partial should have an opaque signature. + fun = lambda *args, **kwargs: None + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + raise TypeError("keyword arguments could not be resolved to positions") + else: + return ba.args + def _initial_style_jaxpr(fun, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr, consts @@ -229,7 +240,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) - args = resolve_kwargs(self.fun, args, kwargs) + args = _resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -588,7 +599,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." raise AttributeError(msg) fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) - args = resolve_kwargs(self.fun, args, kwargs) + args = _resolve_kwargs(self.fun, args, kwargs) if self.optimize_remat: fwd = optimize_remat_of_custom_vjp_fwd( self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, @@ -1440,7 +1451,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # above and it would be good to consolidate it. primal_name = getattr(fun, "__name__", str(fun)) fwd_name = getattr(fwd, "__name__", str(fwd)) - args = resolve_kwargs(fwd, args, kwargs) + args = _resolve_kwargs(fwd, args, kwargs) if nondiff_argnums: for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) diff --git a/tests/api_test.py b/tests/api_test.py index 4aafc42b7a0e..cb0d7c0d40c7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -10798,32 +10798,6 @@ def g(x, a): self.assertAllClose(y, (x + a)**2) - def test_kwargs(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - @f.def_vmap - def rule(axis_size, in_batched, xs): - xs_batched, = in_batched - self.assertEqual(xs_batched, True) - self.assertEqual(axis_size, xs.shape[0]) - return jnp.cos(xs), xs_batched - - x, xs = jnp.array(1.), jnp.arange(3) - y = f(x=x) - self.assertAllClose(y, jnp.sin(x)) - ys = api.vmap(f)(x=xs) - self.assertAllClose(ys, jnp.cos(xs)) - - def test_undefined_rule(self): - @jax.custom_batching.custom_vmap - def f(x): return jnp.sin(x) - - with self.assertRaisesRegex( - AttributeError, "No batching rule defined for custom_vmap function f"): - f(0.5) - - class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""