Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions thunder/core/codeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,13 @@ def _generate_dataclass_class_name(x: object):
# x is an instance of a Dataclass.
# We generate a name for the Dataclass based on the package name and class name so that trace won't have problem
# if there are conflicting names.

assert dataclasses.is_dataclass(x)
name = (x.__class__.__module__ + "_" + x.__class__.__qualname__).replace(".", "_")
if isinstance(x, type):
cls = x
else:
cls = x.__class__
name = (cls.__module__ + "_" + cls.__qualname__).replace(".", "_")
# Class could be a local class in which case it will have `<locals>` in it's module name.
name = name.replace(">", "_").replace("<", "_")
return name
Expand Down Expand Up @@ -149,7 +154,11 @@ def to_printable(

if dataclasses.is_dataclass(x):
# Add `class` to the object_ctx so that we can reuse it during the trace execution.
object_ctx[_generate_dataclass_class_name(x)] = x.__class__
if isinstance(x, type): # dataclass type
cls = x
else: # dataclass type instance
cls = x.__class__
object_ctx[_generate_dataclass_class_name(x)] = cls
# Return the instance as printable object (as function `prettyprint` knows how to deal with it).
return x

Expand Down
27 changes: 17 additions & 10 deletions thunder/core/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,7 +932,9 @@ class PseudoInst(str, enum.Enum):
SUPER = "SUPER"
BUILTINS = "BUILTINS"
STORE_SUBSCR = "STORE_SUBSCR"
STORE_ATTR = "STORE_ATTR"
LIST_TO_TUPLE = "LIST_TO_TUPLE"
NEW = "NEW"


@dataclasses.dataclass
Expand Down Expand Up @@ -2073,9 +2075,13 @@ def impl(fn, iterable, initializer, null):
return _interpret_call(impl, fn, iterable, initializer, null)


class ThunderInterpreterObject:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should have a docstring here

pass


# An iterator to be returned from Sequence.__iter__ lookasides below. This will be run in the interpreter
# Note: this potentially might imitate a list_iterator / tuple_iterator more...
class SequenceIter:
class SequenceIter(ThunderInterpreterObject):
def __init__(self, s, is_reversed=False):
self.s = s
self.next_pos = 0 if not is_reversed else len(s) - 1
Expand Down Expand Up @@ -2377,7 +2383,7 @@ def reverse(self, /):
return wrap_const(None)


class MappingKeysIterator(Iterator):
class MappingKeysIterator(Iterator, ThunderInterpreterObject):
# note: the __init__ will be executed by Python itself, and
# the caller needs to set up the wrapped_attribute for _mapping
# The other methods are called through the interpreter mechanism.
Expand All @@ -2395,7 +2401,7 @@ def __next__(self):
return k


class MappingKeysView:
class MappingKeysView(ThunderInterpreterObject):
def __init__(self, mapping):
self._mapping = mapping

Expand Down Expand Up @@ -2425,7 +2431,7 @@ def __reversed__(self):
return mapping_iter


class MappingValuesIterator:
class MappingValuesIterator(ThunderInterpreterObject):
def __init__(self, mapping, is_reversed=False):
self._mapping = mapping
if is_reversed:
Expand All @@ -2440,15 +2446,15 @@ def __next__(self):
return dict.__getitem__(self._mapping, next(self._key_iter))


class MappingValuesWrapper:
class MappingValuesWrapper(ThunderInterpreterObject):
def __init__(self, mapping):
self._mapping = mapping

def __iter__(self):
return MappingValuesIterator(self._mapping)


class MappingItemsIterator:
class MappingItemsIterator(ThunderInterpreterObject):
def __init__(self, mapping, is_reversed=False):
self._mapping = mapping
if is_reversed:
Expand All @@ -2464,7 +2470,7 @@ def __next__(self):
return k, dict.__getitem__(self._mapping, k)


class MappingItemsWrapper:
class MappingItemsWrapper(ThunderInterpreterObject):
def __init__(self, mapping):
self._mapping = mapping

Expand All @@ -2476,7 +2482,7 @@ class MutMappingWrapperMethods(WrappedValue):
def __new__(cls, /, *args, **kwds):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it be kwargs

uvalue = unwrap(cls)()
# todo: for subclasses, better record the call to the constructor
return wrap_const(uvalue)
return wrap(uvalue, provenance=ProvenanceRecord(PseudoInst.NEW, inputs=[cls.provenance]))

def __init__(self, *other, **kwds):
MutMappingWrapperMethods.update(self, *other, **kwds)
Expand Down Expand Up @@ -2775,7 +2781,6 @@ def _type_call_lookaside(wrapped_typ, *args, **kwargs):
obj = _interpret_call(typ.__new__, wrapped_typ, *args, **kwargs)
if obj is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return obj

wrapped_init = _interpret_call(getattr, obj, wrap_const("__init__"))
assert not isinstance(wrapped_init, INTERPRETER_SIGNALS)
populate_attribute_wrapper(wrapped_init, "__self__", obj)
Expand Down Expand Up @@ -7151,6 +7156,7 @@ def interpret(
callbacks: dict[INTERPRETER_CALLBACKS, Callable] = default_callbacks,
debug_log: None | StringIO = None,
with_provenance_tracking: bool = False,
unwrap_result: bool = True,
uncacheable_classes: list[type] | None = None,
record_history: bool = False,
) -> Callable:
Expand Down Expand Up @@ -7205,7 +7211,8 @@ def fn_2(args, kwargs):
populate_attribute_wrapper(wrapped_cell, "cell_contents", fn_wrapped)

interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)
interpretation_result = unwrap(interpretation_result)
if unwrap_result:
interpretation_result = unwrap(interpretation_result)

except BaseException as e:
# TODO Highlight the portion of the line that originated the opcode on Python versions that include
Expand Down
Loading