From b79aab5d6b0bd0ab7655f31373848156bf9f08a9 Mon Sep 17 00:00:00 2001 From: nikitaved Date: Mon, 18 Mar 2024 10:51:31 -0400 Subject: [PATCH] collections.namedtuple: add lookaside --- thunder/core/interpreter.py | 93 ++++++++++++++++++++++++++++++- thunder/tests/test_interpreter.py | 45 +++++++++++++++ 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index aea2baf52e..2f505a1b4b 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -404,6 +404,10 @@ def __init__( self._uncacheable_classes = uncacheable_classes + @property + def with_provenance_tracking(self): + return self._with_provenance_tracking + def interpret(self, inst: dis.Instruction, /, **interpreter_state) -> None | int | INTERPRETER_SIGNALS: return self._opcode_interpreter(inst, **interpreter_state) @@ -887,6 +891,7 @@ class PseudoInst(str, enum.Enum): BINARY_SUBSCR = "BINARY_SUBSCR" BUILD_DICT = "BUILD_DICT" BUILD_TUPLE = "BUILD_TUPLE" + BUILD_NAMEDTUPLE = "BUILD_NAMEDTUPLE" CONSTANT = "CONSTANT" EXCEPTION_HANDLER = "EXCEPTION_HANDLER" INPUT_ARGS = "INPUT_ARGS" @@ -2589,6 +2594,73 @@ def impl(self, other): return _interpret_call(impl, self, other) +def _collections_namedtuple_lookaside( + typename: str, + field_names: Iterable[str], + *, + rename: bool = False, + defaults: None | Iterable[Any] = None, + module: None | str = None): + # Type checks { + assert wrapped_isinstance(typename, str) + assert wrapped_isinstance(field_names, Iterable) + assert wrapped_isinstance(rename, bool) + if defaults is not None: + assert wrapped_isinstance(defaults, Iterable) + if module is not None: + assert wrapped_isinstance(module, str) + # } + + # Wrap defaults { + if not isinstance(rename, WrappedValue): + rename = wrap_const(rename) + + if defaults is None: + defaults = wrap_const(defaults) + + if module is None: + # To prevent taking module from the direct caller, + # we use the module's name from the active frame + curr_frame = get_interpreterruntimectx().frame_stack[-1] + module = curr_frame.globals.value.get("__name__", None) + module = wrap_const(module) + # } + + # Run opaque namedtuple { + @interpreter_needs_wrap + def create_namedtuple(typename: str, field_names: str, **kwargs): + namedtuple_type = collections.namedtuple(typename, field_names, **kwargs) + return namedtuple_type + + namedtuple_type = create_namedtuple(typename, field_names, rename=rename, defaults=defaults, module=module) + if namedtuple_type is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return namedtuple_type + + assert wrapped_isinstance(namedtuple_type, type) + # } + + # Short circuit if provenance is not being recorded ... { + ctx: InterpreterCompileCtx = get_interpretercompilectx() + if not ctx.with_provenance_tracking: + return namedtuple_type + # } + + # ... otherwise wrap returned type { + unamedtuple_type = unwrap(namedtuple_type) + + @functools.wraps(unamedtuple_type, updated=()) + class NewNamedTuple(unamedtuple_type): + @classmethod + @property + def is_namedtuple(cls): + return True + + namedtuple_type = WrappedValue(NewNamedTuple, provenance=namedtuple_type.provenance) + # } + + return namedtuple_type + + _default_lookaside_map: dict[Callable, Callable] = { # Jit lookasides is_jitting: _is_jitting_lookaside, @@ -2612,6 +2684,7 @@ def impl(self, other): isinstance: _isinstance_lookaside, functools.reduce: _functools_reduce_lookaside, operator.getitem: _getitem_lookaside, + collections.namedtuple: _collections_namedtuple_lookaside, } @@ -2619,9 +2692,11 @@ def impl(self, other): # immutuable sequences (tuples) are created with contents in __new__ and __init__ is a nop # (object.__init__, actually). def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /): + new_tuple_type = cls.value + assert issubclass(new_tuple_type, tuple) + if iterable == (): iterable = wrap_const(()) - assert cls.value is tuple if isinstance(iterable.value, (list, tuple)): # special case to avoid infinite recursion @@ -2648,8 +2723,20 @@ def _tuple_new_provenance_tracking_lookaside(cls, iterable=(), /): else: item_wrappers.append(wv) - ures = tuple(w.value for w in item_wrappers) - pr = ProvenanceRecord(PseudoInst.BUILD_TUPLE, inputs=[w.provenance for w in item_wrappers]) + # Construction of namedtuples may raise + try: + ures = tuple(w.value for w in item_wrappers) + # Named tuples expect varargs, not iterables at new/init + if hasattr(new_tuple_type, 'is_namedtuple'): + ures = new_tuple_type(*ures) + build_inst = PseudoInst.BUILD_NAMEDTUPLE + else: + ures = new_tuple_type(ures) + build_inst = PseudoInst.BUILD_TUPLE + except Exception as e: + return do_raise(e) + + pr = ProvenanceRecord(build_inst, inputs=[w.provenance for w in item_wrappers]) res = wrap(ures, provenance=pr) res.item_wrappers = item_wrappers diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index ab1b45289c..0c8b7202d2 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -1030,6 +1030,51 @@ def add(x, y): assert jfoo((1, 2, 3), jadd) == 6 +def test_namedtuple_lookaside(jit): + from collections import namedtuple + + typename = "MyNamedTuple" + field_names = ('a', 'b', 'c') + + # Test returnign just the type { + def f(): + return namedtuple(typename, field_names) + + jf = jit(f) + + jtype = jf() + assert isinstance(jtype, type) + assert jtype.__name__ == typename + assert all(hasattr(jtype, field) for field in field_names) + # } + + # Test accessing elements { + a = torch.rand(1) + b = torch.rand(1) + c = torch.rand(1) + + def f(a, b, c): + nt = namedtuple(typename, field_names) + obj = nt(a, b, c) + return obj[0] + + jf = jit(f) + + assert f(a, b, c) is a + assert jf(a, b, c) is a + + def f(a, b, c): + nt = namedtuple(typename, field_names) + obj = nt(a, b, c) + return obj.a + + jf = jit(f) + + assert f(a, b, c) is a + assert jf(a, b, c) is a + # } + + def test_calling_methods(jit): jitting = False