Skip to content

Commit

Permalink
Rollback of #22869
Browse files Browse the repository at this point in the history
This is causing breakages due to overly-restrictive checks on kwargs

Reverts 893ae6e

PiperOrigin-RevId: 660803968
  • Loading branch information
Jake VanderPlas authored and jax authors committed Aug 8, 2024
1 parent 9fbc51b commit 551f729
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 49 deletions.
11 changes: 0 additions & 11 deletions jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,17 +556,6 @@ def _assert_no_intersection(static_argnames, donate_argnames):
f"{out} appear in both static_argnames and donate_argnames")


def resolve_kwargs(fun: Callable, args, kwargs):
if isinstance(fun, partial):
fun = lambda *args, **kwargs: None
ba = inspect.signature(fun).bind(*args, **kwargs)
ba.apply_defaults()
if ba.kwargs:
raise TypeError("keyword arguments could not be resolved to positions")
else:
return ba.args


def _dtype(x):
try:
return dtypes.result_type(x)
Expand Down
9 changes: 2 additions & 7 deletions jax/_src/custom_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters.batching import not_mapped
Expand Down Expand Up @@ -64,12 +64,7 @@ def def_vmap(self, vmap_rule: Callable) -> Callable:

@traceback_util.api_boundary
def __call__(self, *args, **kwargs):
fun_name = getattr(self.fun, "__name__", str(self.fun))
if not self.vmap_rule:
raise AttributeError(
f"No batching rule defined for custom_vmap function {fun_name} "
"using def_vmap.")
args = resolve_kwargs(self.fun, args, kwargs)
assert not kwargs
args_flat, in_tree = tree_flatten(args)
flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
Expand Down
21 changes: 16 additions & 5 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Callable, Sequence
import dataclasses
from functools import update_wrapper, reduce, partial, wraps
import inspect
from typing import Any, Generic, TypeVar

from jax._src import config
Expand All @@ -29,8 +30,7 @@
from jax._src import traceback_util
from jax._src.ad_util import (
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
from jax._src.api_util import (
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.core import raise_to_shaped
from jax._src.errors import UnexpectedTracerError
from jax._src.interpreters import ad
Expand All @@ -56,6 +56,17 @@

### util

def _resolve_kwargs(fun, args, kwargs):
if isinstance(fun, partial):
# functools.partial should have an opaque signature.
fun = lambda *args, **kwargs: None
ba = inspect.signature(fun).bind(*args, **kwargs)
ba.apply_defaults()
if ba.kwargs:
raise TypeError("keyword arguments could not be resolved to positions")
else:
return ba.args

def _initial_style_jaxpr(fun, in_avals):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return jaxpr, consts
Expand Down Expand Up @@ -229,7 +240,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
args = resolve_kwargs(self.fun, args, kwargs)
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
args = tuple(_stop_gradient(x) if i in nondiff_argnums else x
Expand Down Expand Up @@ -588,7 +599,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = resolve_kwargs(self.fun, args, kwargs)
args = _resolve_kwargs(self.fun, args, kwargs)
if self.optimize_remat:
fwd = optimize_remat_of_custom_vjp_fwd(
self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums,
Expand Down Expand Up @@ -1440,7 +1451,7 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
# above and it would be good to consolidate it.
primal_name = getattr(fun, "__name__", str(fun))
fwd_name = getattr(fwd, "__name__", str(fwd))
args = resolve_kwargs(fwd, args, kwargs)
args = _resolve_kwargs(fwd, args, kwargs)
if nondiff_argnums:
for i in nondiff_argnums: _check_for_tracers(args[i])
nondiff_argnums_ = set(nondiff_argnums)
Expand Down
26 changes: 0 additions & 26 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10798,32 +10798,6 @@ def g(x, a):

self.assertAllClose(y, (x + a)**2)

def test_kwargs(self):
@jax.custom_batching.custom_vmap
def f(x): return jnp.sin(x)

@f.def_vmap
def rule(axis_size, in_batched, xs):
xs_batched, = in_batched
self.assertEqual(xs_batched, True)
self.assertEqual(axis_size, xs.shape[0])
return jnp.cos(xs), xs_batched

x, xs = jnp.array(1.), jnp.arange(3)
y = f(x=x)
self.assertAllClose(y, jnp.sin(x))
ys = api.vmap(f)(x=xs)
self.assertAllClose(ys, jnp.cos(xs))

def test_undefined_rule(self):
@jax.custom_batching.custom_vmap
def f(x): return jnp.sin(x)

with self.assertRaisesRegex(
AttributeError, "No batching rule defined for custom_vmap function f"):
f(0.5)


class CustomApiTest(jtu.JaxTestCase):
"""Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs"""

Expand Down

0 comments on commit 551f729

Please sign in to comment.