Skip to content

Commit

Permalink
[key reuse] improve some key reuse errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 5, 2024
1 parent 9996b1f commit a5ef345
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 128 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/key_reuse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@
reuse_key as reuse_key,
)

from jax.experimental.key_reuse._common import (
from jax.experimental.key_reuse._core import (
KeyReuseError as KeyReuseError,
)
113 changes: 0 additions & 113 deletions jax/experimental/key_reuse/_common.py

This file was deleted.

107 changes: 99 additions & 8 deletions jax/experimental/key_reuse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@

from collections import defaultdict
from functools import partial, reduce, wraps
from typing import Any, Callable
from typing import Any, Callable, NamedTuple

import jax
from jax import lax
from jax import tree_util
from jax.interpreters import batching, mlir
from jax._src import api_util
from jax._src import config
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import pretty_printer as pp
from jax._src import prng
from jax._src import random
from jax._src import util
Expand All @@ -35,14 +35,102 @@
from jax._src.interpreters import partial_eval as pe
from jax._src.util import weakref_lru_cache

from jax.experimental.key_reuse._common import (
consume_p, assert_consumed_value_p, KeyReuseError,
Sink, Source, Forward, KeyReuseSignature
)
from jax.experimental.shard_map import shard_map_p
import numpy as np


class Sink(NamedTuple):
idx: int
mask: bool | np.ndarray = True

def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Sink({self.idx})"
else:
return f"Sink({self.idx}, mask={self.mask})"


class Source(NamedTuple):
idx: int
mask: bool | np.ndarray = True

def __repr__(self):
if isinstance(self.mask, bool) and self.mask:
return f"Source({self.idx})"
else:
return f"Source({self.idx}, mask={self.mask})"

class Forward(NamedTuple):
in_idx: int
out_idx: int


class KeyReuseSignature(NamedTuple):
sinks: list[Sink]
sources: list[Source]
forwards: list[Forward] = []

def check_signature(self, *args, funcname="function", context=None):
for sink in self.sinks:
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
continue
if np.any(args[sink.idx]._consumed & sink.mask):
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
if context:
msg += " {context}"
raise KeyReuseError(msg)

def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
for source in self.sources:
if isinstance(args_out[source.idx], prng.PRNGKeyArray):
args_out[source.idx]._consumed = ~np.asarray(source.mask)
for forward in self.forwards:
arg_in = args_in[forward.in_idx]
arg_out = args_out[forward.out_idx]
if isinstance(arg_in, prng.PRNGKeyArray) and isinstance(arg_out, prng.PRNGKeyArray):
arg_out._consumed = arg_in._consumed


class KeyReuseError(RuntimeError):
pass

consume_p = core.Primitive("consume")
consume_p.def_impl(lambda x: x)
consume_p.def_abstract_eval(lambda x: x)
batching.defvectorized(consume_p)
mlir.register_lowering(
consume_p,
mlir.lower_fun(lambda x: x, multiple_results=False))

def consume(key):
"""Consume the key and return a consumed copy."""
return consume_p.bind(key)


assert_consumed_value_p = core.Primitive("assert_consumed_value")
assert_consumed_value_p.def_impl(lambda x, *, value: x)
assert_consumed_value_p.def_abstract_eval(lambda x, *, value: x)
batching.defvectorized(assert_consumed_value_p)
mlir.register_lowering(
assert_consumed_value_p,
mlir.lower_fun(lambda x, *, value: x, multiple_results=False))

def assert_unconsumed(key):
"""Assert that a key is unconsumed"""
assert_consumed_value_p.bind(key, value=False)

def assert_consumed(key, value=True):
"""Assert that a key is consumed"""
assert_consumed_value_p.bind(key, value=value)


def _check_consumed_value(eqn, consumed):
"""Extra check for use with assert_consumed_value_p"""
expected = eqn.params['value']
Expand Down Expand Up @@ -341,17 +429,20 @@ def key_reuse_impl_rule(prim, original_rule):
def key_reuse_impl(*args, **kwargs):
if config.enable_key_reuse_checks.value:
if prim == pjit.pjit_p:
funcname = "jit-compiled function"
jaxpr = kwargs['jaxpr'].jaxpr
signature = get_jaxpr_type_signature(jaxpr)
elif prim in key_reuse_signatures:
jaxpr = prim
funcname = str(prim)
jaxpr = None
signature = key_reuse_signatures[prim]
elif prim in key_reuse_signatures_dynamic:
funcname = str(prim)
jaxpr = jax.make_jaxpr(partial(prim.bind, **kwargs))(*args).jaxpr
signature = get_jaxpr_type_signature(jaxpr)
else:
raise RuntimeError(f"Internal: no key reuse rule for primitive {prim}")
signature.check_signature(*args, jaxpr=jaxpr)
signature.check_signature(*args, funcname=funcname)
result = original_rule(*args, **kwargs)
signature.update_consumption(args, result if prim.multiple_results else [result])
return result
Expand Down
13 changes: 7 additions & 6 deletions tests/key_reuse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import jax.numpy as jnp
from jax._src import prng
from jax._src import test_util as jtu
from jax.experimental.key_reuse._common import (
from jax.experimental.key_reuse._core import (
assert_consumed, assert_unconsumed, consume, consume_p)
from jax.experimental.key_reuse import _core, KeyReuseError

Expand Down Expand Up @@ -587,28 +587,29 @@ def f_good(x, key):


class KeyReuseEager(jtu.JaxTestCase):
jit_msg = "Previously-consumed key at index 0 passed to function"
bits_msg = "In random_bits, key values a are already consumed."
jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
traced_bits_msg = "In random_bits, key values a are already consumed."

def test_simple_reuse_nojit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with jax.disable_jit():
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
with self.assertRaisesRegex(KeyReuseError, self.eager_bits_msg):
_ = jax.random.bits(key)

def test_simple_key_reuse_jit(self):
key = jax.random.key(0)
_ = jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.jit_msg):
_ = jax.random.bits(key)
_ = jax.jit(jax.random.bits)(key)

def test_key_reuse_within_jit(self):
@jax.jit
def f():
key = jax.random.key(0)
return jax.random.bits(key) + jax.random.bits(key)
with self.assertRaisesRegex(KeyReuseError, self.bits_msg):
with self.assertRaisesRegex(KeyReuseError, self.traced_bits_msg):
f()


Expand Down

0 comments on commit a5ef345

Please sign in to comment.