Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
more general setattr
Browse files Browse the repository at this point in the history
t-vi committed Jan 14, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 780407d commit ee736c1
Showing 3 changed files with 178 additions and 41 deletions.
27 changes: 17 additions & 10 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
@@ -932,7 +932,9 @@ class PseudoInst(str, enum.Enum):
SUPER = "SUPER"
BUILTINS = "BUILTINS"
STORE_SUBSCR = "STORE_SUBSCR"
STORE_ATTR = "STORE_ATTR"
LIST_TO_TUPLE = "LIST_TO_TUPLE"
NEW = "NEW"


@dataclasses.dataclass
@@ -2073,9 +2075,13 @@ def impl(fn, iterable, initializer, null):
return _interpret_call(impl, fn, iterable, initializer, null)


class ThunderInterpreterObject:
pass


# An iterator to be returned from Sequence.__iter__ lookasides below. This will be run in the interpreter
# Note: this potentially might imitate a list_iterator / tuple_iterator more...
class SequenceIter:
class SequenceIter(ThunderInterpreterObject):
def __init__(self, s, is_reversed=False):
self.s = s
self.next_pos = 0 if not is_reversed else len(s) - 1
@@ -2377,7 +2383,7 @@ def reverse(self, /):
return wrap_const(None)


class MappingKeysIterator(Iterator):
class MappingKeysIterator(Iterator, ThunderInterpreterObject):
# note: the __init__ will be executed by Python itself, and
# the caller needs to set up the wrapped_attribute for _mapping
# The other methods are called through the interpreter mechanism.
@@ -2395,7 +2401,7 @@ def __next__(self):
return k


class MappingKeysView:
class MappingKeysView(ThunderInterpreterObject):
def __init__(self, mapping):
self._mapping = mapping

@@ -2425,7 +2431,7 @@ def __reversed__(self):
return mapping_iter


class MappingValuesIterator:
class MappingValuesIterator(ThunderInterpreterObject):
def __init__(self, mapping, is_reversed=False):
self._mapping = mapping
if is_reversed:
@@ -2440,15 +2446,15 @@ def __next__(self):
return dict.__getitem__(self._mapping, next(self._key_iter))


class MappingValuesWrapper:
class MappingValuesWrapper(ThunderInterpreterObject):
def __init__(self, mapping):
self._mapping = mapping

def __iter__(self):
return MappingValuesIterator(self._mapping)


class MappingItemsIterator:
class MappingItemsIterator(ThunderInterpreterObject):
def __init__(self, mapping, is_reversed=False):
self._mapping = mapping
if is_reversed:
@@ -2464,7 +2470,7 @@ def __next__(self):
return k, dict.__getitem__(self._mapping, k)


class MappingItemsWrapper:
class MappingItemsWrapper(ThunderInterpreterObject):
def __init__(self, mapping):
self._mapping = mapping

@@ -2476,7 +2482,7 @@ class MutMappingWrapperMethods(WrappedValue):
def __new__(cls, /, *args, **kwds):
uvalue = unwrap(cls)()
# todo: for subclasses, better record the call to the constructor
return wrap_const(uvalue)
return wrap(uvalue, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[cls.provenance]))

def __init__(self, *other, **kwds):
MutMappingWrapperMethods.update(self, *other, **kwds)
@@ -2775,7 +2781,6 @@ def _type_call_lookaside(wrapped_typ, *args, **kwargs):
obj = _interpret_call(typ.__new__, wrapped_typ, *args, **kwargs)
if obj is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return obj

wrapped_init = _interpret_call(getattr, obj, wrap_const("__init__"))
assert not isinstance(wrapped_init, INTERPRETER_SIGNALS)
populate_attribute_wrapper(wrapped_init, "__self__", obj)
@@ -7151,6 +7156,7 @@ def interpret(
callbacks: dict[INTERPRETER_CALLBACKS, Callable] = default_callbacks,
debug_log: None | StringIO = None,
with_provenance_tracking: bool = False,
unwrap_result: bool = True,
uncacheable_classes: list[type] | None = None,
record_history: bool = False,
) -> Callable:
@@ -7205,7 +7211,8 @@ def fn_2(args, kwargs):
populate_attribute_wrapper(wrapped_cell, "cell_contents", fn_wrapped)

interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
interpretation_result = unwrap(interpretation_result)
if unwrap_result:
interpretation_result = unwrap(interpretation_result)

except BaseException as e:
# TODO Highlight the portion of the line that originated the opcode on Python versions that include
149 changes: 118 additions & 31 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
@@ -94,6 +94,7 @@
PseudoInst,
ProvenanceRecord,
interpreter_needs_wrap,
ThunderInterpreterObject,
)
from thunder.core.langctxs import set_langctx, reset_langctx, Languages, resolve_language
from thunder.core.baseutils import extract_callable_name
@@ -350,7 +351,7 @@ def proxify(self, value: WrappedValue) -> Any:
)
return proxy_s
else:
raise ValueError("cannot proxify value of {type(uvalue).__type} objects")
raise ValueError(f"cannot proxify value of {type(uvalue).__type__} objects")


_jit_ctx = contextvars.ContextVar("jitctx")
@@ -445,6 +446,21 @@ def _general_jit_getattr_lookaside(obj: Any, name: str, *maybe_default: Any):
getattr_lookaside = default_lookaside(getattr)
assert getattr_lookaside is not None

uobj = unwrap(obj)
uname = unwrap(name)
if isinstance(uobj, AnyProxy):
if uname == "__dict__":
return wrap(
obj.original_value.__dict__,
provenance=ProvenanceRecord(
PseudoInst.LOAD_ATTR,
inputs=[
obj.provenance,
name.provenance,
],
),
)

value = getattr_lookaside(obj, name, *maybe_default)
if value is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return value
@@ -510,16 +526,66 @@ def _general_jit_ordered_dict_setitem(d, key, value):
return dict_setitem_lookaside(d, key, value)


_TORCH_DYNAMIC_TYPES = {
torch.amp.autocast_mode.autocast,
torch.autograd.grad_mode.set_grad_enabled,
torch.autograd.grad_mode.no_grad,
}


def is_created_during_tracing(provenance):
if (
provenance.inst is PseudoInst.OPAQUE
and provenance.inputs[0].inst is PseudoInst.CONSTANT
and provenance.inputs[0].value == object.__new__
):
return True
if provenance.inst is PseudoInst.NEW:
return True
return False


@interpreter_needs_wrap
def _raw_object_setattr(obj: Any, name: str, value: Any):
return object.__setattr__(obj, name, value)


@register_general_jit_lookaside(object.__setattr__)
def _general_jit_object_setattr_lookaside(obj: Any, name: str, value: Any):
uobj = unwrap(obj)
if is_created_during_tracing(obj.provenance) or type(uobj) in _TORCH_DYNAMIC_TYPES:
return _raw_object_setattr(obj, name, value)

if should_register_for_prologue(obj.provenance) and (obj.original_value is obj.nothing):
if getattr(obj.provenance, "proxy", None) is None:
p: AnyProxy = AnyProxy(uobj, history=obj.provenance)
obj.provenance.proxy = p
obj.register_proxy(p)
uobj = p

d = _interpret_call(getattr, obj, wrap_const("__dict__"))
if d is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return d
d.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
ud = unwrap(d)
assert type(ud) == dict
res = _interpret_call(ud.__setitem__, name, value)
return res


@register_general_jit_lookaside(setattr)
def _general_jit_setattr_lookaside(obj: Any, name: str, value: Any):
setattr_lookaside = default_lookaside(setattr)
assert setattr_lookaside is not None

uobj = unwrap(obj)
uname = unwrap(name)

if isinstance(uobj, torch.nn.Module):
# 1) modify the inner thing
# 2) divert the actual setattr...
# 1) populate the wrappeers for the member dicts
# 2) let the original setattr do it's thing by modifying the
# the member dict
# This might generalize to other things, too...
for n in MODULE_MEMBER_DICT_ATTRS:
member_dict = _interpret_call(getattr, obj, wrap_const(n))
member_dict.provenance.ext_flag |= EXT_FLAG_IS_MODULE_MEMBER_DICT
@@ -682,9 +748,12 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar

custom_autograd_function_cls = unwrap(obj)
custom_forward = custom_autograd_function_cls.forward
ctx = torch.autograd.function.FunctionCtx()
typ = torch.autograd.function.FunctionCtx
ctx = typ()
ctx_proxy = proxy(ctx, name=None, history=None)
wrapped_ctx = wrap_const(ctx_proxy)
wrapped_ctx = wrap_const(
ctx_proxy, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[wrap_const(typ).provenance])
)
trace_of_fwd, fwd_output_provenance = _convert_pytorchfunc_to_thundertrace(
custom_forward, True, wrapped_ctx, *args, **kwargs
)
@@ -1241,10 +1310,10 @@ def _general_jit_global_callback(orig_value: Any, name: str) -> Any:
return orig_value


_safe_provenance_inst = {
_input_provenance_inst = {
"INPUT_ARGS",
"INPUT_KWARGS",
"INPUT_FN",
"INPUT_FN", # or self
"LOAD_ATTR",
"CONSTANT",
"BINARY_SUBSCR",
@@ -1259,7 +1328,7 @@ def should_register_for_prologue(pr):
inst = inst.opname
else:
inst = inst.value
if inst not in _safe_provenance_inst:
if inst not in _input_provenance_inst:
return False
if inst == "CONSTANT" and callable(pr.value):
if pr.value.__name__ != "__getitem__" and pr.value != GetSetDescriptorType.__get__:
@@ -1509,6 +1578,7 @@ def from_load_attr(provenance, *, new_output=False):
output = Proxy(prefix="obj")
else:
output = p

param_ordering[id(output)] = (output, param_ordering[id(orig_obj)][1] + [math.inf, "." + str(name)])
bsym = prims.unpack_attr.bind(obj, name, output=output)
prologue_trace.bound_symbols.append(bsym)
@@ -1766,27 +1836,43 @@ def process_recorded_modifications(ctx, epilogue_trace):
for k, (inst, *args) in last_modification.items():
if inst == PseudoInst.STORE_SUBSCR:
(value,) = args
assert isinstance(value.value, Proxy)

assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR
assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
assert modified_object.provenance.inputs[1].value == "_buffers"

typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
root_module_proxy = root_for_provenances.get(root_module_provenance)
if root_module_proxy is None:
## we want this to created in the compute trace context for namespace...
root_module_proxy = Proxy(history=root_module_provenance)
epilogue_trace.add_name(root_module_proxy.name)
root_for_provenances[root_module_provenance] = root_module_proxy

name = ".".join(name + [k])
with tracectx(epilogue_trace):
bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None)
epilogue_trace.bound_symbols.append(bsym)
assert isinstance(value.value, (Proxy, int, tuple)) ## todo: better criterion

if (
modified_object.provenance.inst is PseudoInst.LOAD_ATTR
and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
and modified_object.provenance.inputs[1].value == "_buffers"
):
typ, name, root_module_provenance = get_parameter_or_buffer_or_submodule_name_and_root(
modified_object.provenance.inputs[0]
)
assert typ == "_modules"
root_module_proxy = root_for_provenances.get(root_module_provenance)
if root_module_proxy is None:
## we want this to created in the compute trace context for namespace...
root_module_proxy = Proxy(history=root_module_provenance)
epilogue_trace.add_name(root_module_proxy.name)
root_for_provenances[root_module_provenance] = root_module_proxy

name = ".".join(name + [k])
with tracectx(epilogue_trace):
bsym = prims.pack_buffer.bind(root_module_proxy, name, value.value, output=None)
epilogue_trace.bound_symbols.append(bsym)
elif (
modified_object.provenance.inst is PseudoInst.LOAD_ATTR
and modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT
and modified_object.provenance.inputs[1].value == "__dict__"
):
name = k
setattr_obj_provenance = modified_object.provenance.inputs[0]
if hasattr(setattr_obj_provenance, "proxy"):
setattr_obj_proxy = setattr_obj_provenance.proxy
with tracectx(epilogue_trace):
bsym = prims.pack_attr.bind(setattr_obj_proxy, name, value.value, output=None)
epilogue_trace.bound_symbols.append(bsym)
else:
raise NotImplementedError(f"Modifications of {modified_object.provenance} are not supported")
else:
raise NotImplementedError(f"Modifications {inst} on dicts are not supported")
else:
@@ -1882,6 +1968,7 @@ def thunder_general_jit(
fn_lookaside=general_jit_lookaside,
callbacks=general_jit_callbacks,
with_provenance_tracking=True,
unwrap_result=False,
uncacheable_classes=(torch.Tensor, int, float, str, NoneType),
record_history=compile_data.debug_options.record_interpreter_history,
)
@@ -1891,11 +1978,12 @@ def thunder_general_jit(
result = jfn(*args, **kwargs)
computation_trace.set_current_source_location(None, None)
process_recorded_modifications(ctx, epilogue_trace)
uresult = unwrap(result)
last_interpreter_log = jfn._last_interpreter_log
result_proxies = tuple(p for p in tree_iter(result) if isinstance(p, (TensorProxy, NumberProxy)))
result_proxies = tuple(p for p in tree_iter(uresult) if isinstance(p, (TensorProxy, NumberProxy)))
prims.python_return(result_proxies)
with tracectx(epilogue_trace):
prims.python_return(result)
prims.python_return(uresult)

pro_to_comp, pro_to_comp_set, computation_intermediates = get_computation_inputs_and_intermediates(
computation_trace
@@ -1958,5 +2046,4 @@ def restrict_proxy_swapmap(proxies: tuple[Proxy]) -> dict[Variable, Proxy]:
epilogue_trace = _apply_trace_proxy_rename(
epilogue_trace, restrict_proxy_swapmap(pro_to_epi_proxies + comp_to_epi_proxies), "epilogue"
)

return TraceResults(prologue_trace, computation_trace, epilogue_trace, last_interpreter_log)
43 changes: 43 additions & 0 deletions thunder/core/prims.py
Original file line number Diff line number Diff line change
@@ -128,6 +128,7 @@ class PrimIDs(Enum):
UNPACK_THUNDER_MODULE = auto()
CONSTRUCT_TUPLE = auto()
PACK_BUFFER = auto()
PACK_ATTR = auto()
PACK_SETITEM = auto()
SHAPE = auto()
# TODO: UNPACK_SET
@@ -1206,6 +1207,48 @@ def pack_buffer_impl(o: Any, key: Any, v: Any) -> None:
)


# NOTE PACK_ATTR is intended only to be bound to directly, and not called
def pack_attr_meta(o: Any, key: Any, value: Any) -> Any:
raise NotImplementedError


def pack_attr_printer(
bsym: BoundSymbol, out_printables: Any, arg_printables: Sequence[Printable], kwarg_printables: dict[str, Printable]
):
utils.check(
len(arg_printables) == 3,
lambda: f"Expected three arguments for pack_attr but got {arg_printables}",
exception_type=AssertionError,
)
utils.check(
len(kwarg_printables) == 0,
lambda: f"Expected no kwargs for pack_attr but got {kwarg_printables}",
exception_type=AssertionError,
)

# Converts printables to strings
obj, key, value = arg_printables
obj_str = codeutils.prettyprint(obj)
key_str = key
value_str = codeutils.prettyprint(value)
return f"{obj_str}.{key_str} = {value_str}"


def pack_attr_impl(o: Any, key: Any, v: Any) -> None:
o[key] = v
return None


pack_attr = make_prim(
PrimIDs.PACK_ATTR,
"pack_attr",
meta=pack_attr_meta,
python_printer=pack_attr_printer,
python_impl=pack_attr_impl,
tags=(OpTags.DONT_DCE,),
)


# NOTE PACK_SETITEM is intended only to be bound to directly, and not called
def pack_setitem_meta(o: Any, key: Any, value: Any) -> Any:
raise NotImplementedError

0 comments on commit ee736c1

Please sign in to comment.