diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 81929293ef..450359f982 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -31,6 +31,7 @@ jobs: requires: ["latest", "nightly"] # , 'oldest' include: - { os: "ubuntu-22.04", python-version: "3.11", requires: "latest" } + - { os: "ubuntu-22.04", python-version: "3.12", requires: "latest" } # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 @@ -104,7 +105,7 @@ jobs: --durations=250 - name: Testing just a few - if: matrix.python-version == '3.11' + if: matrix.python-version == '3.11' || matrix.python-version == '3.12' #continue-on-error: true run: | python -m pytest \ diff --git a/setup.py b/setup.py index 817caa05da..26fed14342 100755 --- a/setup.py +++ b/setup.py @@ -113,7 +113,7 @@ def _load_readme_description(path_dir: str, homepage: str, version: str) -> str: include_package_data=True, zip_safe=False, keywords=["deep learning", "AI"], - python_requires=">=3.10, <3.12", + python_requires=">=3.10, <3.13", setup_requires=["wheel"], install_requires=_load_requirements(_PATH_REQUIRES, file_name="base.txt"), extras_require=_prepare_extras(), diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 8a4f717ed9..c0a24566f3 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -924,6 +924,7 @@ class PseudoInst(str, enum.Enum): SUPER = "SUPER" BUILTINS = "BUILTINS" STORE_SUBSCR = "STORE_SUBSCR" + LIST_TO_TUPLE = "LIST_TO_TUPLE" @dataclasses.dataclass @@ -1133,7 +1134,7 @@ def nexti(self, inst: dis.Instruction): if (3, 9) <= sys.version_info < (3, 11): if inst.starts_line is not None: self.positions = Positions(inst.starts_line, inst.starts_line, 0, 999) - elif (3, 11) <= sys.version_info < (3, 12): + elif (3, 11) <= sys.version_info < (3, 13): if inst.positions is not None: self.positions = inst.positions else: @@ -1238,6 +1239,9 @@ def __call__(self, fn: Callable) -> Callable: ): assert self.name not in _default_opcode_handler_map, self.name assert self.name in dis.opmap, self.name + assert ( + self.name.lower() in fn.__name__ + ), f"opcode handler name mismatch {self.name.lower()} vs. {fn.__name__}" _default_opcode_handler_map[self.name] = fn return fn return _default_opcode_handler_map.get(self.name, fn) @@ -1557,7 +1561,10 @@ def lookup_descriptor_field(field_name): # if it is opaque, don't _interpret_call here, to avoid a wrap/unwrap dance if is_opaque(descr_get): - return descr_get(cls_var, uobj, objtype) + try: + return descr_get(cls_var, uobj, objtype) + except Exception as e: + return do_raise(e) result = _interpret_call_with_unwrapping(descr_get, cls_var, obj, objtype) return result @@ -3103,9 +3110,10 @@ def check_signal(val): # https://docs.python.org/3.11/library/dis.html#opcode-ASYNC_GEN_WRAP -@register_opcode_handler("ASYNC_GEN_WRAP", min_ver=(3, 11)) +@register_opcode_handler("ASYNC_GEN_WRAP", min_ver=(3, 11), max_ver=(3, 11)) def _async_gen_wrap_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: # the next thing will be to yield the value, but we delegate this along with the wrapping to thunder_interpreter_async_generator + # update the intrinsic for 3.12+, too pass @@ -3463,6 +3471,24 @@ def _inplace_xor_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kw return _binary_op_helper(stack, BINARY_OP.IXOR) +# https://docs.python.org/3.12/library/dis.html#opcode-BINARY_SLICE +@register_opcode_handler("BINARY_SLICE", min_ver=(3, 12)) +def _binary_slice_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: + end = stack.pop_wrapped() + start = stack.pop_wrapped() + container = stack.pop_wrapped() + + def impl(container, start, end): + return container.__getitem__(slice(start, end)) + + res = _interpret_call(impl, container, start, end) + + if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return res + + return check_and_append(stack, res) + + # https://docs.python.org/3.10/library/dis.html#opcode-BINARY_SUBSCR @register_opcode_handler("BINARY_SUBSCR") def _binary_subscr_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: @@ -3668,6 +3694,90 @@ def _call_function_kw_handler( return check_and_append(stack, _interpret_call_with_unwrapping(func, *args, **fn_kwargs)) +def _list_to_tuple_intrinsic(tos): + assert wrapped_isinstance(tos, list) + populate_item_wrappers(tos) + + res = tuple(unwrap(tos)) + + ctx: InterpreterCompileCtx = get_interpretercompilectx() + if ctx._with_provenance_tracking: + pr = ProvenanceRecord(PseudoInst.LIST_TO_TUPLE, inputs=[tos.provenance]) + res = wrap(res, provenance=pr) + res.item_wrappers = tos.item_wrappers[:] + + return res + + +def _stopiteration_error_intrinsic(exc): + runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx() + co_flags = runtimectx.frame_stack[-1].code.co_flags + + assert wrapped_isinstance(exc, Exception) + # CPython 3.12 asserts whether frame->owner == FRAME_OWNED_BY_GENERATOR + assert co_flags & (inspect.CO_COROUTINE | inspect.CO_GENERATOR | inspect.CO_ASYNC_GENERATOR) + + msg = None + if wrapped_isinstance(exc, StopIteration): + msg = "generator raised StopIteration" + if co_flags & inspect.CO_ASYNC_GENERATOR: + msg = "async generator raised StopIteration" + elif co_flags & inspect.CO_COROUTINE: + msg = "coroutine raised StopIteration" + elif (co_flags & inspect.CO_ASYNC_GENERATOR) and wrapped_isinstance(exc, StopAsyncIteration): + msg = "async generator raised StopAsyncIteration" + + if msg: + compile_ctx: InterpreterCompileCtx = get_interpretercompilectx() + if compile_ctx._with_provenance_tracking: + msg = wrap_const(msg) + + def impl(exc, msg): + error = RuntimeError(msg) + error.__cause__ = exc.value + return error + + return _interpret_call(impl, exc, msg) + + return exc + + +def _async_gen_wrap_intrinsic(v): + # noop for now + # see ASYNC_GEN_WRAP opcode for 3.11 + return v + + +# https://docs.python.org/3.12/library/dis.html#opcode-CALL_INTRINSIC_1 +@register_opcode_handler("CALL_INTRINSIC_1", min_ver=(3, 12)) +def _call_intrinsic_1_handler( + inst: dis.Instruction, /, stack: InterpreterStack, **kwargs +) -> None | INTERPRETER_SIGNALS: + assert type(inst.arg) is int + intrinsics_1 = { + # INTRINSIC_1_INVALID + "INTRINSIC_PRINT": _print_intrinsic, + "INTRINSIC_LIST_TO_TUPLE": _list_to_tuple_intrinsic, + "INTRINSIC_IMPORT_STAR": _import_star_intrinsic, + "INTRINSIC_STOPITERATION_ERROR": _stopiteration_error_intrinsic, + "INTRINSIC_ASYNC_GEN_WRAP": _async_gen_wrap_intrinsic, + "INTRINSIC_UNARY_POSITIVE": _unary_positive_intrinsic, + # INTRINSIC_TYPEVAR + # INTRINSIC_PARAMSPEC + # INTRINSIC_TYPEVARTUPLE + # INTRINSIC_SUBSCRIPT_GENERIC + # INTRINSIC_TYPEALIAS + } + intrinsic_name = dis._intrinsic_1_descs[inst.arg] + + tos = stack.pop_wrapped() + fn = intrinsics_1.get(intrinsic_name) + if fn is None: + raise NotImplementedError(f"CALL_INTRINSIC_1 {intrinsic_name}") + res = fn(tos) + return check_and_append(stack, res) + + # https://docs.python.org/3.10/library/dis.html#opcode-CALL_METHOD @register_opcode_handler("CALL_METHOD", max_ver=(3, 10)) def _call_method_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: @@ -3747,12 +3857,18 @@ def _compare_op_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwa ">": lambda x, y: x > y, ">=": lambda x, y: x >= y, } + b = stack.pop() a = stack.pop() assert type(inst.arg) is int - assert inst.arg < len(dis.cmp_op), f"{inst}, {dis.cmp_op}" + if sys.version_info >= (3, 12): + # this is not in the dis.dis page... + op_nr = inst.arg >> 4 + else: + op_nr = inst.arg + assert op_nr < len(dis.cmp_op), f"{inst}, {dis.cmp_op}" - op = cmp_impls[dis.cmp_op[inst.arg]] + op = cmp_impls[dis.cmp_op[op_nr]] res: bool = op(unwrap(a), unwrap(b)) stack.append(res) @@ -4016,6 +4132,19 @@ def _end_async_for_handler_3_11( return INTERPRETER_SIGNALS.EXCEPTION_RAISED +# https://docs.python.org/3.11/library/dis.html#opcode-END_FOR +@register_opcode_handler("END_FOR", min_ver=(3, 12)) +def _end_for_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: + stack.pop_wrapped() + stack.pop_wrapped() + + +# https://docs.python.org/3.12/library/dis.html#opcode-END_SEND +@register_opcode_handler("END_SEND", min_ver=(3, 12)) +def _end_send_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: + del stack[-2] + + # https://docs.python.org/3.10/library/dis.html#opcode-EXTENDED_ARG @register_opcode_handler("EXTENDED_ARG") def _extended_arg_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: @@ -4084,8 +4213,14 @@ def _next_impl(tos): if r is INTERPRETER_SIGNALS.EXCEPTION_RAISED: ctx = get_interpreterruntimectx() if isinstance(ctx.curexc, StopIteration): - stack.pop_wrapped() - return inst_ptr + delta + 1 + if sys.version_info >= (3, 12): + # 3.12 uses jumps relative to the next instruction offset and does not pop here + # instead it pushes a fake value?! + stack.append(Py_NULL()) + return delta + else: + stack.pop_wrapped() + return inst_ptr + delta + 1 return r stack.append(r) @@ -4259,34 +4394,19 @@ def impl(module_name, fromlist, level): return check_and_append(stack, _interpret_call_with_unwrapping(impl, module_name, fromlist, level)) -# https://docs.python.org/3.10/library/dis.html#opcode-IMPORT_STAR -@register_opcode_handler("IMPORT_STAR") -def _import_star_handler( - inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs -) -> None | INTERPRETER_SIGNALS: - # The module is actually imported from another instruction. - # This instruction can only be parsed at top level and modify globals, - # since localsplus is of fixed length/positions. It can't be parsed inside a function. - - # `from operator import *` compiles as - # 0 LOAD_CONST 0 (0) - # 2 LOAD_CONST 1 (('*',)) - # 4 IMPORT_NAME 0 (operator) - # 6 IMPORT_STAR - - module = stack.pop() - assert isinstance(module, ModuleType) +def _import_star_intrinsic(module): + assert wrapped_isinstance(module, ModuleType) # Get the locals of the current frame, not the frame created by interpreted impl() below. - _locals = _interpret_call_with_unwrapping(locals) + _locals = _interpret_call(locals) if _locals is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return _locals - assert isinstance(_locals, dict) + assert wrapped_isinstance(_locals, dict) # For every name in __all__ if present in the module, or every name in __dict__ not # starting with _ if __all__ is not present, add the name to the current locals() dict, # and produce the same exceptions as cpython would. - def impl(): + def impl(module, _locals): skip_leading_underscores = False all_names = getattr(module, "__all__", None) if all_names is None: @@ -4310,7 +4430,27 @@ def impl(): continue _locals[name] = getattr(module, name) - res = _interpret_call_with_unwrapping(impl) + return _interpret_call(impl, module, _locals) + + +# https://docs.python.org/3.10/library/dis.html#opcode-IMPORT_STAR +@register_opcode_handler("IMPORT_STAR", max_ver=(3, 11)) +def _import_star_handler( + inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs +) -> None | INTERPRETER_SIGNALS: + # The module is actually imported from another instruction. + # This instruction can only be parsed at top level and modify globals, + # since localsplus is of fixed length/positions. It can't be parsed inside a function. + + # `from operator import *` compiles as + # 0 LOAD_CONST 0 (0) + # 2 LOAD_CONST 1 (('*',)) + # 4 IMPORT_NAME 0 (operator) + # 6 IMPORT_STAR + + module = stack.pop_wrapped() + + res = _import_star_intrinsic(module) if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return res @@ -4338,6 +4478,8 @@ def _jump_absolute_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> def _jump_forward_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> int: assert type(inst.arg) is int delta: int = inst.arg + if sys.version_info >= (3, 12): + return delta return inst_ptr + delta + 1 @@ -4346,6 +4488,8 @@ def _jump_forward_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> def _jump_backward_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> int: assert type(inst.arg) is int delta: int = inst.arg + if sys.version_info >= (3, 12): + return -delta return inst_ptr - delta + 1 @@ -4354,6 +4498,8 @@ def _jump_backward_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> def _jump_backward_no_interrupt_handler(inst: dis.Instruction, /, inst_ptr: int, **kwargs) -> int: assert type(inst.arg) is int delta: int = inst.arg + if sys.version_info >= (3, 12): + return -delta return inst_ptr - delta + 1 @@ -4376,7 +4522,7 @@ def _jump_if_not_exc_match_handler( # https://docs.python.org/3.10/library/dis.html#opcode-JUMP_TRUE_OR_POP # https://docs.python.org/3.11/library/dis.html#opcode-JUMP_TRUE_OR_POP -@register_opcode_handler("JUMP_IF_TRUE_OR_POP") +@register_opcode_handler("JUMP_IF_TRUE_OR_POP", max_ver=(3, 11)) def _jump_if_true_or_pop_handler( inst: dis.Instruction, /, inst_ptr: int, stack: InterpreterStack, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -4398,9 +4544,9 @@ def _jump_if_true_or_pop_handler( return target -# https://docs.python.org/3.10/library/dis.html#opcode-JUMP_FALSE_OR_POP -# https://docs.python.org/3.11/library/dis.html#opcode-JUMP_FALSE_OR_POP -@register_opcode_handler("JUMP_IF_FALSE_OR_POP") +# https://docs.python.org/3.10/library/dis.html#opcode-JUMP_IF_FALSE_OR_POP +# https://docs.python.org/3.11/library/dis.html#opcode-JUMP_IF_FALSE_OR_POP +@register_opcode_handler("JUMP_IF_FALSE_OR_POP", max_ver=(3, 11)) def _jump_if_false_or_pop_handler( inst: dis.Instruction, /, inst_ptr: int, stack: InterpreterStack, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -4461,31 +4607,20 @@ def _list_extend_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kw # https://docs.python.org/3.10/library/dis.html#opcode-LIST_TO_TUPLE -@register_opcode_handler("LIST_TO_TUPLE") +@register_opcode_handler("LIST_TO_TUPLE", max_ver=(3, 11)) def _list_to_tuple_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: tos = stack.pop_wrapped() - assert wrapped_isinstance(tos, list) - populate_item_wrappers(tos) - - res = tuple(unwrap(tos)) - - ctx: InterpreterCompileCtx = get_interpretercompilectx() - if ctx._with_provenance_tracking: - pr = ProvenanceRecord(inst, inputs=[tos.provenance]) - res = wrap(res, provenance=pr) - res.item_wrappers = tos.item_wrappers[:] - - stack.append(res) + return check_and_append(stack, _list_to_tuple_intrinsic(tos)) # https://docs.python.org/3.13/library/dis.html#opcode-LOAD_ASSERTION_ERROR @register_opcode_handler("LOAD_ASSERTION_ERROR") -def _load_assertion_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: +def _load_assertion_error_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None: stack.append(wrap_const(AssertionError)) # https://docs.python.org/3.10/library/dis.html#opcode-LOAD_ATTR -@register_opcode_handler("LOAD_ATTR") +@register_opcode_handler("LOAD_ATTR", max_ver=(3, 11)) def _load_attr_handler( inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs ) -> None | INTERPRETER_SIGNALS: @@ -4497,6 +4632,51 @@ def _load_attr_handler( return check_and_append(stack, _interpret_call(getattr, a, name)) +# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_ATTR +@register_opcode_handler("LOAD_ATTR", min_ver=(3, 12)) +def _load_attr_handler_3_12( + inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs +) -> None | INTERPRETER_SIGNALS: + assert type(inst.arg) is int + idx = inst.arg >> 1 + load_method_like = inst.arg & 1 + obj = stack.pop_wrapped() + name: WrappedValue = wrap_const(co.co_names[idx]) + if not load_method_like: + return check_and_append(stack, _interpret_call(getattr, obj, name)) + else: + return load_method_helper(obj, name, stack) + + +# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_ATTR +@register_opcode_handler("LOAD_SUPER_ATTR", min_ver=(3, 12)) +def _load_super_attr_handler( + inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs +) -> None | INTERPRETER_SIGNALS: + assert type(inst.arg) is int + idx = inst.arg >> 2 + load_method_like = (inst.arg & 1) > 0 + two_argument_super = (inst.arg & 2) > 0 + + _self = stack.pop_wrapped() + _cls = stack.pop_wrapped() + _super = stack.pop_wrapped() # ??? + + if two_argument_super: + super_object = _interpret_call(_super, _cls, _self) + else: + super_object = _interpret_call(_super) + + if super_object is INTERPRETER_SIGNALS.EXCEPTION_RAISED: + return super_object + + name: WrappedValue = wrap_const(co.co_names[idx]) + if not load_method_like: + return check_and_append(stack, _interpret_call(getattr, super_object, name)) + else: + return load_method_helper(super_object, name, stack) + + # https://docs.python.org/3.10/library/dis.html#opcode-LOAD_BUILD_CLASS @register_opcode_handler("LOAD_BUILD_CLASS") def _load_build_class_handler( @@ -4570,8 +4750,11 @@ def _load_deref_handler( # https://docs.python.org/3.10/library/dis.html#opcode-LOAD_FAST +# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_FAST_CHECK +# LOAD_FAST for Python <3.12 is LOAD_FAST_CHECK +@register_opcode_handler("LOAD_FAST_CHECK", min_ver=(3, 12)) @register_opcode_handler("LOAD_FAST") -def _load_fast_handler( +def _load_fast_check_handler( inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs ) -> None | INTERPRETER_SIGNALS: assert isinstance(inst.arg, int) @@ -4594,6 +4777,34 @@ def _load_fast_handler( return check_and_append(stack, val) +# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_FAST_AND_CLEAR +@register_opcode_handler("LOAD_FAST_AND_CLEAR", min_ver=(3, 12)) +def _load_fast_and_clear_handler( + inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs +) -> None | INTERPRETER_SIGNALS: + assert isinstance(inst.arg, int) + var_num: int = inst.arg + assert var_num >= 0 and var_num < len(frame.localsplus) + + val: Any = frame.localsplus[var_num] + name: str = frame.get_localsplus_name(var_num) + + # empty local variable slots are initialized to Py_NULL(), in this + # case we push Py_NULL (but wrapped) + if isinstance(val, Py_NULL): + val = wrap_const(Py_NULL()) + + ctx: InterpreterCompileCtx = get_interpretercompilectx() + if ctx._with_provenance_tracking: + assert isinstance(val, WrappedValue), f"unexpected value of type {type(val)}, {val}, {inst}" + + val = load_fast_callback(val, name) + # clear the local variable + frame.localsplus[var_num] = Py_NULL() + + return check_and_append(stack, val) + + # https://docs.python.org/3.10/library/dis.html#opcode-LOAD_GLOBAL @register_opcode_handler("LOAD_GLOBAL") def _load_global_handler( @@ -4632,9 +4843,12 @@ def _load_method_handler( inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, **kwargs ) -> None | INTERPRETER_SIGNALS: assert type(inst.arg) is int - name = wrap_const(co.co_names[inst.arg]) obj = stack.pop_wrapped() + name = wrap_const(co.co_names[inst.arg]) + return load_method_helper(obj, name, stack) + +def load_method_helper(obj, name, stack): meth = _interpret_call(getattr, obj, name) if meth is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return meth @@ -4959,7 +5173,7 @@ def _pop_except_handler_3_11( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_BACKWARD_IF_FALSE -@register_opcode_handler("POP_JUMP_BACKWARD_IF_FALSE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_BACKWARD_IF_FALSE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_backward_if_false_handler( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -4978,7 +5192,7 @@ def _pop_jump_backward_if_false_handler( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_BACKWARD_IF_NONE -@register_opcode_handler("POP_JUMP_BACKWARD_IF_NONE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_BACKWARD_IF_NONE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_backward_if_none_handler( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None: @@ -4993,7 +5207,7 @@ def _pop_jump_backward_if_none_handler( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_BACKWARD_IF_NOT_NONE -@register_opcode_handler("POP_JUMP_BACKWARD_IF_NOT_NONE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_BACKWARD_IF_NOT_NONE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_backward_if_not_none_handler( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None: @@ -5008,7 +5222,7 @@ def _pop_jump_backward_if_not_none_handler( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_BACKWARD_IF_TRUE -@register_opcode_handler("POP_JUMP_BACKWARD_IF_TRUE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_BACKWARD_IF_TRUE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_backward_if_true_handler( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -5027,7 +5241,7 @@ def _pop_jump_backward_if_true_handler( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_FORWARD_IF_FALSE -@register_opcode_handler("POP_JUMP_FORWARD_IF_FALSE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_FORWARD_IF_FALSE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_forward_if_false_handler( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -5046,7 +5260,7 @@ def _pop_jump_forward_if_false_handler( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_FORWARD_IF_TRUE -@register_opcode_handler("POP_JUMP_FORWARD_IF_TRUE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_FORWARD_IF_TRUE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_forward_if_true_handler_3_11( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -5065,7 +5279,7 @@ def _pop_jump_forward_if_true_handler_3_11( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_FORWARD_IF_NONE -@register_opcode_handler("POP_JUMP_FORWARD_IF_NONE", min_ver=(3, 11)) +@register_opcode_handler("POP_JUMP_FORWARD_IF_NONE", min_ver=(3, 11), max_ver=(3, 11)) def _pop_jump_forward_if_none_handler_3_11( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None: @@ -5079,8 +5293,8 @@ def _pop_jump_forward_if_none_handler_3_11( # https://docs.python.org/3.11/library/dis.html#opcode-POP_JUMP_FORWARD_IF_NOT_NONE -@register_opcode_handler("POP_JUMP_FORWARD_IF_NOT_NONE", min_ver=(3, 11)) -def _pop_jump_forward_if_none_handler( +@register_opcode_handler("POP_JUMP_FORWARD_IF_NOT_NONE", min_ver=(3, 11), max_ver=(3, 11)) +def _pop_jump_forward_if_not_none_handler( inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs ) -> int | None: assert isinstance(inst.arg, int) @@ -5093,6 +5307,7 @@ def _pop_jump_forward_if_none_handler( @register_opcode_handler("POP_JUMP_IF_FALSE", max_ver=(3, 10)) +@register_opcode_handler("POP_JUMP_IF_FALSE", min_ver=(3, 12)) def _pop_jump_if_false_handler( inst: dis.Instruction, /, stack: InterpreterStack, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -5108,11 +5323,13 @@ def _pop_jump_if_false_handler( cnd: bool = ures if not cnd: + # note that inst.arg is relative for 3.12 and absolute for 3.10 return inst.arg return None @register_opcode_handler("POP_JUMP_IF_TRUE", max_ver=(3, 10)) +@register_opcode_handler("POP_JUMP_IF_TRUE", min_ver=(3, 12)) def _pop_jump_if_true_handler( inst: dis.Instruction, /, stack: InterpreterStack, **kwargs ) -> int | None | INTERPRETER_SIGNALS: @@ -5137,7 +5354,36 @@ def impl(): cnd = tmp if cnd: + # note that inst.arg is relative for 3.12 and absolute for 3.10 + return inst.arg + return None + + +# https://docs.python.org/3.12/library/dis.html#opcode-POP_JUMP_IF_NONE +@register_opcode_handler("POP_JUMP_IF_NONE", min_ver=(3, 12)) +def _pop_jump_if_none_handler(inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs) -> int | None: + assert isinstance(inst.arg, int) + + tos = stack.pop() + if tos is None: + # in 3.12 return is relative to the next opcode + return inst.arg + + return None + + +# https://docs.python.org/3.12/library/dis.html#opcode-POP_JUMP_IF_NOT_NONE +@register_opcode_handler("POP_JUMP_IF_NOT_NONE", min_ver=(3, 12)) +def _pop_jump_if_not_none_handler( + inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs +) -> int | None: + assert isinstance(inst.arg, int) + + tos = stack.pop() + if tos is not None: + # in 3.12 return is relative to the next opcode return inst.arg + return None @@ -5210,23 +5456,27 @@ def do_raise(exc: Any = Py_NULL(), cause: Any = Py_NULL()) -> Literal[INTERPRETE return INTERPRETER_SIGNALS.EXCEPTION_RAISED -# https://docs.python.org/3.11/library/dis.html#opcode-PRINT_EXPR -@register_opcode_handler("PRINT_EXPR") -def _print_expr_handler( - inst: dis.Instruction, /, stack: InterpreterStack, frame: InterpreterFrame, **kwargs -) -> None | INTERPRETER_SIGNALS: - def impl(tos): +def _print_intrinsic(expr): + def impl(expr): # NOTE: There is no other way to obtain the display hook, other # than writing a C extension, so we mangle. # NOTE: The display hook's return value is ignored by cpython. # NOTE: By default, type(sys.__displayhook__) is . from sys import displayhook as __thunder_sys_displayhook - __thunder_sys_displayhook(tos) - return None + __thunder_sys_displayhook(expr) - tos = stack.pop() - val = _interpret_call_with_unwrapping(impl, tos) + return _interpret_call(impl, expr) + + +# https://docs.python.org/3.11/library/dis.html#opcode-PRINT_EXPR +@register_opcode_handler("PRINT_EXPR", max_ver=(3, 11)) +def _print_expr_handler( + inst: dis.Instruction, /, stack: InterpreterStack, frame: InterpreterFrame, **kwargs +) -> None | INTERPRETER_SIGNALS: + expr = stack.pop_wrapped() + + val = _print_intrinsic(expr) if val is INTERPRETER_SIGNALS.EXCEPTION_RAISED: return val return None @@ -5309,6 +5559,20 @@ def _reraise_handler_3_11( return INTERPRETER_SIGNALS.EXCEPTION_RAISED +# https://docs.python.org/3.12/library/dis.html#opcode-RETURN_CONST +@register_opcode_handler("RETURN_CONST", min_ver=(3, 12)) +def _return_const_handler( + inst: dis.Instruction, /, co: CodeType, stack: InterpreterStack, **kwargs +) -> int | None | INTERPRETER_SIGNALS: + assert type(inst.arg) is int + + constant = co.co_consts[inst.arg] + constant = wrap_const(constant) + constant = const_callback(constant) + stack.append(constant) + return INTERPRETER_SIGNALS.RETURN_VALUE + + # https://docs.python.org/3.10/library/dis.html#opcode-RETURN_VALUE @register_opcode_handler("RETURN_VALUE") def _return_value_handler(inst: dis.Instruction, /, **kwargs) -> int | None | INTERPRETER_SIGNALS: @@ -5592,6 +5856,20 @@ def impl(names_dict, name, value): return check_signal(_interpret_call(impl, frame.names, wrap_const(name), tos)) +# https://docs.python.org/3.12/library/dis.html#opcode-STORE_SLICE +@register_opcode_handler("STORE_SLICE", min_ver=(3, 12)) +def _store_slice_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: + end = stack.pop_wrapped() + start = stack.pop_wrapped() + container = stack.pop_wrapped() + values = stack.pop_wrapped() + + def impl(container, start, end, values): + return container.__setitem__(slice(start, end), values) + + return _interpret_call_with_unwrapping(impl, container, start, end, values) + + # https://docs.python.org/3.10/library/dis.html#opcode-STORE_SUBSCR @register_opcode_handler("STORE_SUBSCR") def _store_subscr_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: @@ -5650,12 +5928,8 @@ def impl(): return check_and_append(stack, _interpret_call_with_unwrapping(impl)) -# https://docs.python.org/3.10/library/dis.html#opcode-UNARY_POSITIVE -@register_opcode_handler("UNARY_POSITIVE") -def _unary_positive_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: - tos = stack.pop() - - def impl(): +def _unary_positive_intrinsic(tos): + def impl(tos): if hasattr(tos, "__pos__"): result = tos.__pos__() if result is not NotImplemented: @@ -5663,7 +5937,15 @@ def impl(): raise TypeError(f"bad operand type for unary +: '{type(tos).__name__}'") - return check_and_append(stack, _interpret_call_with_unwrapping(impl)) + return _interpret_call_with_unwrapping(impl, tos) + + +# https://docs.python.org/3.10/library/dis.html#opcode-UNARY_POSITIVE +@register_opcode_handler("UNARY_POSITIVE", max_ver=(3, 11)) +def _unary_positive_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS: + tos = stack.pop() + res = _unary_positive_intrinsic(tos) + return check_and_append(stack, res) # https://docs.python.org/3.10/library/dis.html#opcode-UNPACK_EX @@ -5801,10 +6083,16 @@ def impl(): if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED: runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx() if isinstance(runtimectx.curexc, StopIteration): - stack.pop() # remove generator - stack.append(runtimectx.curexc.value) + retval = runtimectx.curexc.value runtimectx.curexc = None - return inst_ptr + inst.arg + 1 + if sys.version_info < (3, 12): + stack.pop() # remove generator + stack.append(retval) + return inst_ptr + inst.arg + 1 + else: + # Python 3.12 keeps the generator, returns relative jump + stack.append(retval) + return inst.arg else: return res # propagate exception @@ -6393,7 +6681,7 @@ def _setup_frame_and_run_python_function( for i, (name, value) in enumerate(zip(code.co_freevars, closure)): local = freevar_callback(name, value, fn=wrapped_fn, idx=i) localsplus.append(local) - elif (3, 11) <= sys.version_info < (3, 12): + elif (3, 11) <= sys.version_info < (3, 13): assert len(code.co_varnames) == code.co_nlocals for n in code.co_varnames: local = locals_dict.get(n, Py_NULL()) @@ -6461,6 +6749,9 @@ def _run_frame( insts: tuple[dis.Instruction, ...] = tuple(dis.get_instructions(frame.code)) # adjustments for "hidden" instructions (EXTENDED_ARGS, CACHE, ...) inst_ptr_to_idx = {inst.offset // 2: idx for idx, inst in enumerate(insts)} + idx_to_next_inst_ptr = [inst.offset // 2 for inst in insts[1:]] + idx_to_next_inst_ptr.append(len(frame.code.co_code)) + max_inst_ptr = max(inst_ptr_to_idx.keys()) while True: # we might have jumped or advanced to a "hidden" instruction such as cache, @@ -6473,7 +6764,6 @@ def _run_frame( assert frame.inst_ptr <= max_inst_ptr frame.inst_ptr += 1 inst: dis.Instruction = insts[inst_ptr_to_idx[frame.inst_ptr]] - # Updates the stack frame to the current position # TODO maybe also have inst_ptr? frame.nexti(inst) @@ -6517,7 +6807,13 @@ def _run_frame( frame.interpreter_stack.append(current_exception) current_exception = None skip_stack_effect_check = True - interpretation_result = et_handler // 2 + if sys.version_info >= (3, 12): + # yeah, this is *really* ugly to make interpretation_result stay relative... maybe introduce another signal + interpretation_result = ( + et_handler // 2 - idx_to_next_inst_ptr[inst_ptr_to_idx[frame.inst_ptr]] + ) + else: + interpretation_result = et_handler // 2 else: # This is Python 3.10-style unwinding skip_stack_effect_check = True # or only do this in ifs below? @@ -6612,7 +6908,11 @@ def _run_frame( frame.inst_ptr += 1 else: assert isinstance(interpretation_result, int), interpretation_result - frame.inst_ptr = interpretation_result + if sys.version_info >= (3, 12): + # in Python >= 3.12 all jumps are relative to the next instruction + frame.inst_ptr = idx_to_next_inst_ptr[inst_ptr_to_idx[frame.inst_ptr]] + interpretation_result + else: + frame.inst_ptr = interpretation_result if not skip_stack_effect_check: # the exception handling will change the stack wildly # Verifies the handler had the expected stack effect (delta on stack sie) diff --git a/thunder/tests/test_interpreter.py b/thunder/tests/test_interpreter.py index 445bea756f..060d5123a5 100644 --- a/thunder/tests/test_interpreter.py +++ b/thunder/tests/test_interpreter.py @@ -2119,7 +2119,13 @@ def test_list_to_tuple(jit): def ltt(): return (*[1, 2, 3],) - assert any(i.opname == "LIST_TO_TUPLE" for i in dis.get_instructions(ltt)) + if sys.version_info >= (3, 12): + assert any( + (i.opname == "CALL_INTRINSIC_1" and i.argrepr == "INTRINSIC_LIST_TO_TUPLE") + for i in dis.get_instructions(ltt) + ) + else: + assert any(i.opname == "LIST_TO_TUPLE" for i in dis.get_instructions(ltt)) assert jit(ltt)() == ltt() @@ -2744,6 +2750,10 @@ def foo(x): jfn() +@pytest.mark.skipif( + sys.version_info >= (3, 12), + reason="Python 3.12 code.InteractiveInterpreter().runsource does not use the displayhook as before", +) def test_displayhook(jit): from contextlib import redirect_stdout import io