From ee736c11d4b1b1c7f7029369c034ec4fced5aece Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 14 Jan 2025 15:41:51 +0100 Subject: [PATCH] more general setattr --- thunder/core/interpreter.py | 27 ++++--- thunder/core/jit_ext.py | 149 ++++++++++++++++++++++++++++-------- thunder/core/prims.py | 43 +++++++++++ 3 files changed, 178 insertions(+), 41 deletions(-) diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 6557df1802..b68dc390c3 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -932,7 +932,9 @@ class PseudoInst(str, enum.Enum): SUPER = "SUPER" BUILTINS = "BUILTINS" STORE_SUBSCR = "STORE_SUBSCR" + STORE_ATTR = "STORE_ATTR" LIST_TO_TUPLE = "LIST_TO_TUPLE" + NEW = "NEW" @dataclasses.dataclass @@ -2073,9 +2075,13 @@ def impl(fn, iterable, initializer, null): return _interpret_call(impl, fn, iterable, initializer, null) +class ThunderInterpreterObject: + pass + + # An iterator to be returned from Sequence.__iter__ lookasides below. This will be run in the interpreter # Note: this potentially might imitate a list_iterator / tuple_iterator more... -class SequenceIter: +class SequenceIter(ThunderInterpreterObject): def __init__(self, s, is_reversed=False): self.s = s self.next_pos = 0 if not is_reversed else len(s) - 1 @@ -2377,7 +2383,7 @@ def reverse(self, /): return wrap_const(None) -class MappingKeysIterator(Iterator): +class MappingKeysIterator(Iterator, ThunderInterpreterObject): # note: the __init__ will be executed by Python itself, and # the caller needs to set up the wrapped_attribute for _mapping # The other methods are called through the interpreter mechanism. @@ -2395,7 +2401,7 @@ def __next__(self): return k -class MappingKeysView: +class MappingKeysView(ThunderInterpreterObject): def __init__(self, mapping): self._mapping = mapping @@ -2425,7 +2431,7 @@ def __reversed__(self): return mapping_iter -class MappingValuesIterator: +class MappingValuesIterator(ThunderInterpreterObject): def __init__(self, mapping, is_reversed=False): self._mapping = mapping if is_reversed: @@ -2440,7 +2446,7 @@ def __next__(self): return dict.__getitem__(self._mapping, next(self._key_iter)) -class MappingValuesWrapper: +class MappingValuesWrapper(ThunderInterpreterObject): def __init__(self, mapping): self._mapping = mapping @@ -2448,7 +2454,7 @@ def __iter__(self): return MappingValuesIterator(self._mapping) -class MappingItemsIterator: +class MappingItemsIterator(ThunderInterpreterObject): def __init__(self, mapping, is_reversed=False): self._mapping = mapping if is_reversed: @@ -2464,7 +2470,7 @@ def __next__(self): return k, dict.__getitem__(self._mapping, k) -class MappingItemsWrapper: +class MappingItemsWrapper(ThunderInterpreterObject): def __init__(self, mapping): self._mapping = mapping @@ -2476,7 +2482,7 @@ class MutMappingWrapperMethods(WrappedValue): def __new__(cls, /, *args, **kwds): uvalue = unwrap(cls)() # todo: for subclasses, better record the call to the constructor - return wrap_const(uvalue) + return wrap(uvalue, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[cls.provenance])) def __init__(self, *other, **kwds): MutMappingWrapperMethods.update(self, *other, **kwds) @@ -2775,7 +2781,6 @@ def _type_call_lookaside(wrapped_typ, *args, **kwargs): obj = _interpret_call(typ.__new__, wrapped_typ, *args, **kwargs) if obj is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return obj - wrapped_init = _interpret_call(getattr, obj, wrap_const("__init__")) assert not isinstance(wrapped_init, INTERPRETER_SIGNALS) populate_attribute_wrapper(wrapped_init, "__self__", obj) @@ -7151,6 +7156,7 @@ def interpret( callbacks: dict[INTERPRETER_CALLBACKS, Callable] = default_callbacks, debug_log: None | StringIO = None, with_provenance_tracking: bool = False, + unwrap_result: bool = True, uncacheable_classes: list[type] | None = None, record_history: bool = False, ) -> Callable: @@ -7205,7 +7211,8 @@ def fn_2(args, kwargs): populate_attribute_wrapper(wrapped_cell, "cell_contents", fn_wrapped) interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs) - interpretation_result = unwrap(interpretation_result) + if unwrap_result: + interpretation_result = unwrap(interpretation_result) except BaseException as e: # TODO Highlight the portion of the line that originated the opcode on Python versions that include diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index 0112ff3faa..d348ea842f 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -94,6 +94,7 @@ PseudoInst, ProvenanceRecord, interpreter_needs_wrap, + ThunderInterpreterObject, ) from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language from thunder.core.baseutils import extract_callable_name @@ -350,7 +351,7 @@ def proxify(self, value: WrappedValue) -> Any: ) return proxy_s else: - raise ValueError("cannot proxify value of {type(uvalue).__type} objects") + raise ValueError(f"cannot proxify value of {type(uvalue).__type__} objects") _jit_ctx = contextvars.ContextVar("jitctx") @@ -445,6 +446,21 @@ def _general_jit_getattr_lookaside(obj: Any, name: str, *maybe_default: Any): getattr_lookaside = default_lookaside(getattr) assert getattr_lookaside is not None + uobj = unwrap(obj) + uname = unwrap(name) + if isinstance(uobj, AnyProxy): + if uname == "__dict__": + return wrap( + obj.original_value.__dict__, + provenance=ProvenanceRecord( + PseudoInst.LOAD_ATTR, + inputs=[ + obj.provenance, + name.provenance, + ], + ), + ) + value = getattr_lookaside(obj, name, *maybe_default) if value is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return value @@ -510,6 +526,53 @@ def _general_jit_ordered_dict_setitem(d, key, value): return dict_setitem_lookaside(d, key, value) +_TORCH_DYNAMIC_TYPES = { + torch.amp.autocast_mode.autocast, + torch.autograd.grad_mode.set_grad_enabled, + torch.autograd.grad_mode.no_grad, +} + + +def is_created_during_tracing(provenance): + if ( + provenance.inst is PseudoInst.OPAQUE + and provenance.inputs[0].inst is PseudoInst.CONSTANT + and provenance.inputs[0].value == object.__new__ + ): + return True + if provenance.inst is PseudoInst.NEW: + return True + return False + + +@interpreter_needs_wrap +def _raw_object_setattr(obj: Any, name: str, value: Any): + return object.__setattr__(obj, name, value) + + +@register_general_jit_lookaside(object.__setattr__) +def _general_jit_object_setattr_lookaside(obj: Any, name: str, value: Any): + uobj = unwrap(obj) + if is_created_during_tracing(obj.provenance) or type(uobj) in _TORCH_DYNAMIC_TYPES: + return _raw_object_setattr(obj, name, value) + + if should_register_for_prologue(obj.provenance) and (obj.original_value is obj.nothing): + if getattr(obj.provenance, "proxy", None) is None: + p: AnyProxy = AnyProxy(uobj, history=obj.provenance) + obj.provenance.proxy = p + obj.register_proxy(p) + uobj = p + + d = _interpret_call(getattr, obj, wrap_const("__dict__")) + if d is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return d + d.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT + ud = unwrap(d) + assert type(ud) == dict + res = _interpret_call(ud.__setitem__, name, value) + return res + + @register_general_jit_lookaside(setattr) def _general_jit_setattr_lookaside(obj: Any, name: str, value: Any): setattr_lookaside = default_lookaside(setattr) @@ -517,9 +580,12 @@ def _general_jit_setattr_lookaside(obj: Any, name: str, value: Any): uobj = unwrap(obj) uname = unwrap(name) + if isinstance(uobj, torch.nn.Module): - # 1) modify the inner thing - # 2) divert the actual setattr... + # 1) populate the wrappeers for the member dicts + # 2) let the original setattr do it's thing by modifying the + # the member dict + # This might generalize to other things, too... for n in MODULE_MEMBER_DICT_ATTRS: member_dict = _interpret_call(getattr, obj, wrap_const(n)) member_dict.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT @@ -682,9 +748,12 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar custom_autograd_function_cls = unwrap(obj) custom_forward = custom_autograd_function_cls.forward - ctx = torch.autograd.function.FunctionCtx() + typ = torch.autograd.function.FunctionCtx + ctx = typ() ctx_proxy = proxy(ctx, name=None, history=None) - wrapped_ctx = wrap_const(ctx_proxy) + wrapped_ctx = wrap_const( + ctx_proxy, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[wrap_const(typ).provenance]) + ) trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace( custom_forward, True, wrapped_ctx, *args, **kwargs ) @@ -1241,10 +1310,10 @@ def _general_jit_global_callback(orig_value: Any, name: str) -> Any: return orig_value -_safe_provenance_inst = { +_input_provenance_inst = { "INPUT_ARGS", "INPUT_KWARGS", - "INPUT_FN", + "INPUT_FN", # or self "LOAD_ATTR", "CONSTANT", "BINARY_SUBSCR", @@ -1259,7 +1328,7 @@ def should_register_for_prologue(pr): inst = inst.opname else: inst = inst.value - if inst not in _safe_provenance_inst: + if inst not in _input_provenance_inst: return False if inst == "CONSTANT" and callable(pr.value): if pr.value.__name__ != "__getitem__" and pr.value != GetSetDescriptorType.__get__: @@ -1509,6 +1578,7 @@ def from_load_attr(provenance, *, new_output=False): output = Proxy(prefix="obj") else: output = p + param_ordering[id(output)] = (output, param_ordering[id(orig_obj)][1] + [math.inf, "." + str(name)]) bsym = prims.unpack_attr.bind(obj, name, output=output) prologue_trace.bound_symbols.append(bsym) @@ -1766,27 +1836,43 @@ def process_recorded_modifications(ctx, epilogue_trace): for k, (inst, *args) in last_modification.items(): if inst == PseudoInst.STORE_SUBSCR: (value,) = args - assert isinstance(value.value, Proxy) - assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR - assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT - assert modified_object.provenance.inputs[1].value == "_buffers" - - typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root( - modified_object.provenance.inputs[0] - ) - assert typ == "_modules" - root_module_proxy = root_for_provenances.get(root_module_provenance) - if root_module_proxy is None: - ## we want this to created in the compute trace context for namespace... - root_module_proxy = Proxy(history=root_module_provenance) - epilogue_trace.add_name(root_module_proxy.name) - root_for_provenances[root_module_provenance] = root_module_proxy - - name = ".".join(name + [k]) - with tracectx(epilogue_trace): - bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None) - epilogue_trace.bound_symbols.append(bsym) + assert isinstance(value.value, (Proxy, int, tuple)) ## todo: better criterion + + if ( + modified_object.provenance.inst is PseudoInst.LOAD_ATTR + and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT + and modified_object.provenance.inputs[1].value == "_buffers" + ): + typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root( + modified_object.provenance.inputs[0] + ) + assert typ == "_modules" + root_module_proxy = root_for_provenances.get(root_module_provenance) + if root_module_proxy is None: + ## we want this to created in the compute trace context for namespace... + root_module_proxy = Proxy(history=root_module_provenance) + epilogue_trace.add_name(root_module_proxy.name) + root_for_provenances[root_module_provenance] = root_module_proxy + + name = ".".join(name + [k]) + with tracectx(epilogue_trace): + bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None) + epilogue_trace.bound_symbols.append(bsym) + elif ( + modified_object.provenance.inst is PseudoInst.LOAD_ATTR + and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT + and modified_object.provenance.inputs[1].value == "__dict__" + ): + name = k + setattr_obj_provenance = modified_object.provenance.inputs[0] + if hasattr(setattr_obj_provenance, "proxy"): + setattr_obj_proxy = setattr_obj_provenance.proxy + with tracectx(epilogue_trace): + bsym = prims.pack_attr.bind(setattr_obj_proxy, name, value.value, output=None) + epilogue_trace.bound_symbols.append(bsym) + else: + raise NotImplementedError(f"Modifications of {modified_object.provenance} are not supported") else: raise NotImplementedError(f"Modifications {inst} on dicts are not supported") else: @@ -1882,6 +1968,7 @@ def thunder_general_jit( fn_lookaside=general_jit_lookaside, callbacks=general_jit_callbacks, with_provenance_tracking=True, + unwrap_result=False, uncacheable_classes=(torch.Tensor, int, float, str, NoneType), record_history=compile_data.debug_options.record_interpreter_history, ) @@ -1891,11 +1978,12 @@ def thunder_general_jit( result = jfn(*args, **kwargs) computation_trace.set_current_source_location(None, None) process_recorded_modifications(ctx, epilogue_trace) + uresult = unwrap(result) last_interpreter_log = jfn._last_interpreter_log - result_proxies = tuple(p for p in tree_iter(result) if isinstance(p, (TensorProxy, NumberProxy))) + result_proxies = tuple(p for p in tree_iter(uresult) if isinstance(p, (TensorProxy, NumberProxy))) prims.python_return(result_proxies) with tracectx(epilogue_trace): - prims.python_return(result) + prims.python_return(uresult) pro_to_comp, pro_to_comp_set, computation_intermediates = get_computation_inputs_and_intermediates( computation_trace @@ -1958,5 +2046,4 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]: epilogue_trace = _apply_trace_proxy_rename( epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue" ) - return TraceResults(prologue_trace, computation_trace, epilogue_trace, last_interpreter_log) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 790be71cb8..682e81acbc 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -128,6 +128,7 @@ class PrimIDs(Enum): UNPACK_THUNDER_MODULE = auto() CONSTRUCT_TUPLE = auto() PACK_BUFFER = auto() + PACK_ATTR = auto() PACK_SETITEM = auto() SHAPE = auto() # TODO: UNPACK_SET @@ -1206,6 +1207,48 @@ def pack_buffer_impl(o: Any, key: Any, v: Any) -> None: ) +# NOTE PACK_ATTR is intended only to be bound to directly, and not called +def pack_attr_meta(o: Any, key: Any, value: Any) -> Any: + raise NotImplementedError + + +def pack_attr_printer( + bsym: BoundSymbol, out_printables: Any, arg_printables: Sequence[Printable], kwarg_printables: dict[str, Printable] +): + utils.check( + len(arg_printables) == 3, + lambda: f"Expected three arguments for pack_attr but got {arg_printables}", + exception_type=AssertionError, + ) + utils.check( + len(kwarg_printables) == 0, + lambda: f"Expected no kwargs for pack_attr but got {kwarg_printables}", + exception_type=AssertionError, + ) + + # Converts printables to strings + obj, key, value = arg_printables + obj_str = codeutils.prettyprint(obj) + key_str = key + value_str = codeutils.prettyprint(value) + return f"{obj_str}.{key_str} = {value_str}" + + +def pack_attr_impl(o: Any, key: Any, v: Any) -> None: + o[key] = v + return None + + +pack_attr = make_prim( + PrimIDs.PACK_ATTR, + "pack_attr", + meta=pack_attr_meta, + python_printer=pack_attr_printer, + python_impl=pack_attr_impl, + tags=(OpTags.DONT_DCE,), +) + + # NOTE PACK_SETITEM is intended only to be bound to directly, and not called def pack_setitem_meta(o: Any, key: Any, value: Any) -> Any: raise NotImplementedError