Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[key reuse] improve some key reuse errors. #20070

Merged
merged 1 commit into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax/experimental/key_reuse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@
reuse_key as reuse_key,
)

from jax.experimental.key_reuse._common import (
from jax.experimental.key_reuse._core import (
KeyReuseError as KeyReuseError,
)
113 changes: 0 additions & 113 deletions jax/experimental/key_reuse/_common.py

This file was deleted.

107 changes: 99 additions & 8 deletions jax/experimental/key_reuse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@

from collections import defaultdict
from functools import partial, reduce, wraps
from typing import Any, Callable
from typing import Any, Callable, NamedTuple

import jax
from jax import lax
from jax import tree_util
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import pretty_printer as pp
from jax._src import prng
from jax._src import random
from jax._src import util
Expand All @@ -35,14 +35,102 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache

from jax.experimental.key_reuse._common import (
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, Forward, KeyReuseSignature
)
from jax.experimental.shard_map import shard_map_p
import numpy as np


class Sink(NamedTuple):
idx: int
mask: bool | np.ndarray = True

def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Sink({self.idx})"
else:
return f"Sink({self.idx}, mask={self.mask})"


class Source(NamedTuple):
idx: int
mask: bool | np.ndarray = True

def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Source({self.idx})"
else:
return f"Source({self.idx}, mask={self.mask})"

class Forward(NamedTuple):
in_idx: int
out_idx: int


class KeyReuseSignature(NamedTuple):
sinks: list[Sink]
sources: list[Source]
forwards: list[Forward] = []

def check_signature(self, *args, funcname="function", context=None):
for sink in self.sinks:
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
continue
if np.any(args[sink.idx]._consumed & sink.mask):
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
if context:
msg += " {context}"
raise KeyReuseError(msg)

def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
for source in self.sources:
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
args_out[source.idx]._consumed = ~np.asarray(source.mask)
for forward in self.forwards:
arg_in = args_in[forward.in_idx]
arg_out = args_out[forward.out_idx]
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
arg_out._consumed = arg_in._consumed


class KeyReuseError(RuntimeError):
pass

consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_abstract_eval(lambda x: x)
batching.defvectorized(consume_p)
mlir.register_lowering(
consume_p,
mlir.lower_fun(lambda x: x, multiple_results=False))

def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)


assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
batching.defvectorized(assert_consumed_value_p)
mlir.register_lowering(
assert_consumed_value_p,
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))

def assert_unconsumed(key):
"""Assert that a key is unconsumed"""
assert_consumed_value_p.bind(key, value=False)

def assert_consumed(key, value=True):
"""Assert that a key is consumed"""
assert_consumed_value_p.bind(key, value=value)


def _check_consumed_value(eqn, consumed):
"""Extra check for use with assert_consumed_value_p"""
expected = eqn.params['value']
Expand Down Expand Up @@ -341,17 +429,20 @@ def key_reuse_impl_rule(prim, original_rule):
def key_reuse_impl(*args, **kwargs):
if config.enable_key_reuse_checks.value:
if prim == pjit.pjit_p:
funcname = "jit-compiled function"
jaxpr = kwargs['jaxpr'].jaxpr
signature = get_jaxpr_type_signature(jaxpr)
elif prim in key_reuse_signatures:
jaxpr = prim
funcname = str(prim)
jaxpr = None
signature = key_reuse_signatures[prim]
elif prim in key_reuse_signatures_dynamic:
funcname = str(prim)
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
signature = get_jaxpr_type_signature(jaxpr)
else:
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
signature.check_signature(*args, jaxpr=jaxpr)
signature.check_signature(*args, funcname=funcname)
result = original_rule(*args, **kwargs)
signature.update_consumption(args, result if prim.multiple_results else [result])
return result
Expand Down
13 changes: 7 additions & 6 deletions tests/key_reuse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax.numpy as jnp
from jax._src import prng
from jax._src import test_util as jtu
from jax.experimental.key_reuse._common import (
from jax.experimental.key_reuse._core import (
assert_consumed, assert_unconsumed, consume, consume_p)
from jax.experimental.key_reuse import _core, KeyReuseError

Expand Down Expand Up @@ -587,28 +587,29 @@ def f_good(x, key):


class KeyReuseEager(jtu.JaxTestCase):
jit_msg = "Previously-consumed key at index 0 passed to function"
bits_msg = "In random_bits, key values a are already consumed."
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
traced_bits_msg = "In random_bits, key values a are already consumed."

def test_simple_reuse_nojit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with jax.disable_jit():
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
_ = jax.random.bits(key)

def test_simple_key_reuse_jit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.random.bits(key)
_ = jax.jit(jax.random.bits)(key)

def test_key_reuse_within_jit(self):
@jax.jit
def f():
key = jax.random.key(0)
return jax.random.bits(key) + jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.bits_msg):
with self.assertRaisesRegex(KeyReuseError, self.traced_bits_msg):
f()


Expand Down