Skip to content

Commit

Permalink
collections.namedtuple: add lookaside
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved committed Mar 25, 2024
1 parent 2e0bb61 commit b79aab5
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 3 deletions.
93 changes: 90 additions & 3 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,10 @@ def __init__(

self._uncacheable_classes = uncacheable_classes

@property
def with_provenance_tracking(self):
return self._with_provenance_tracking

def interpret(self, inst: dis.Instruction, /, **interpreter_state) -> None | int | INTERPRETER_SIGNALS:
return self._opcode_interpreter(inst, **interpreter_state)

Expand Down Expand Up @@ -887,6 +891,7 @@ class PseudoInst(str, enum.Enum):
BINARY_SUBSCR = "BINARY_SUBSCR"
BUILD_DICT = "BUILD_DICT"
BUILD_TUPLE = "BUILD_TUPLE"
BUILD_NAMEDTUPLE = "BUILD_NAMEDTUPLE"
CONSTANT = "CONSTANT"
EXCEPTION_HANDLER = "EXCEPTION_HANDLER"
INPUT_ARGS = "INPUT_ARGS"
Expand Down Expand Up @@ -2589,6 +2594,73 @@ def impl(self, other):
return _interpret_call(impl, self, other)


def _collections_namedtuple_lookaside(
typename: str,
field_names: Iterable[str],
*,
rename: bool = False,
defaults: None | Iterable[Any] = None,
module: None | str = None):
# Type checks {
assert wrapped_isinstance(typename, str)
assert wrapped_isinstance(field_names, Iterable)
assert wrapped_isinstance(rename, bool)
if defaults is not None:
assert wrapped_isinstance(defaults, Iterable)
if module is not None:
assert wrapped_isinstance(module, str)
# }

# Wrap defaults {
if not isinstance(rename, WrappedValue):
rename = wrap_const(rename)

if defaults is None:
defaults = wrap_const(defaults)

if module is None:
# To prevent taking module from the direct caller,
# we use the module's name from the active frame
curr_frame = get_interpreterruntimectx().frame_stack[-1]
module = curr_frame.globals.value.get("__name__", None)
module = wrap_const(module)
# }

# Run opaque namedtuple {
@interpreter_needs_wrap
def create_namedtuple(typename: str, field_names: str, **kwargs):
namedtuple_type = collections.namedtuple(typename, field_names, **kwargs)
return namedtuple_type

namedtuple_type = create_namedtuple(typename, field_names, rename=rename, defaults=defaults, module=module)
if namedtuple_type is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return namedtuple_type

assert wrapped_isinstance(namedtuple_type, type)
# }

# Short circuit if provenance is not being recorded ... {
ctx: InterpreterCompileCtx = get_interpretercompilectx()
if not ctx.with_provenance_tracking:
return namedtuple_type
# }

# ... otherwise wrap returned type {
unamedtuple_type = unwrap(namedtuple_type)

@functools.wraps(unamedtuple_type, updated=())
class NewNamedTuple(unamedtuple_type):
@classmethod
@property
def is_namedtuple(cls):
return True

namedtuple_type = WrappedValue(NewNamedTuple, provenance=namedtuple_type.provenance)
# }

return namedtuple_type


_default_lookaside_map: dict[Callable, Callable] = {
# Jit lookasides
is_jitting: _is_jitting_lookaside,
Expand All @@ -2612,16 +2684,19 @@ def impl(self, other):
isinstance: _isinstance_lookaside,
functools.reduce: _functools_reduce_lookaside,
operator.getitem: _getitem_lookaside,
collections.namedtuple: _collections_namedtuple_lookaside,
}


# While mutuable sequences (lists) are created empty in __new__ and populated in __init__,
# immutuable sequences (tuples) are created with contents in __new__ and __init__ is a nop
# (object.__init__, actually).
def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
new_tuple_type = cls.value
assert issubclass(new_tuple_type, tuple)

if iterable == ():
iterable = wrap_const(())
assert cls.value is tuple

if isinstance(iterable.value, (list, tuple)):
# special case to avoid infinite recursion
Expand All @@ -2648,8 +2723,20 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /):
else:
item_wrappers.append(wv)

ures = tuple(w.value for w in item_wrappers)
pr = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[w.provenance for w in item_wrappers])
# Construction of namedtuples may raise
try:
ures = tuple(w.value for w in item_wrappers)
# Named tuples expect varargs, not iterables at new/init
if hasattr(new_tuple_type, 'is_namedtuple'):
ures = new_tuple_type(*ures)
build_inst = PseudoInst.BUILD_NAMEDTUPLE
else:
ures = new_tuple_type(ures)
build_inst = PseudoInst.BUILD_TUPLE
except Exception as e:
return do_raise(e)

pr = ProvenanceRecord(build_inst, inputs=[w.provenance for w in item_wrappers])
res = wrap(ures, provenance=pr)
res.item_wrappers = item_wrappers

Expand Down
45 changes: 45 additions & 0 deletions thunder/tests/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,51 @@ def add(x, y):
assert jfoo((1, 2, 3), jadd) == 6


def test_namedtuple_lookaside(jit):
from collections import namedtuple

typename = "MyNamedTuple"
field_names = ('a', 'b', 'c')

# Test returnign just the type {
def f():
return namedtuple(typename, field_names)

jf = jit(f)

jtype = jf()
assert isinstance(jtype, type)
assert jtype.__name__ == typename
assert all(hasattr(jtype, field) for field in field_names)
# }

# Test accessing elements {
a = torch.rand(1)
b = torch.rand(1)
c = torch.rand(1)

def f(a, b, c):
nt = namedtuple(typename, field_names)
obj = nt(a, b, c)
return obj[0]

jf = jit(f)

assert f(a, b, c) is a
assert jf(a, b, c) is a

def f(a, b, c):
nt = namedtuple(typename, field_names)
obj = nt(a, b, c)
return obj.a

jf = jit(f)

assert f(a, b, c) is a
assert jf(a, b, c) is a
# }


def test_calling_methods(jit):
jitting = False

Expand Down

0 comments on commit b79aab5

Please sign in to comment.