Skip to content

Commit 313014b

Browse files
committed
updates
1 parent 54d7b9b commit 313014b

File tree

2 files changed

+94
-12
lines changed

2 files changed

+94
-12
lines changed

thunder/core/interpreter.py

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,6 +1239,9 @@ def __call__(self, fn: Callable) -> Callable:
12391239
):
12401240
assert self.name not in _default_opcode_handler_map, self.name
12411241
assert self.name in dis.opmap, self.name
1242+
assert (
1243+
self.name.lower() in fn.__name__
1244+
), f"opcode handler name mismatch {self.name.lower()} vs. {fn.__name__}"
12421245
_default_opcode_handler_map[self.name] = fn
12431246
return fn
12441247
return _default_opcode_handler_map.get(self.name, fn)
@@ -3444,6 +3447,24 @@ def _inplace_xor_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kw
34443447
return _binary_op_helper(stack, BINARY_OP.IXOR)
34453448

34463449

3450+
# https://docs.python.org/3.12/library/dis.html#opcode-BINARY_SLICE
3451+
@register_opcode_handler("BINARY_SLICE")
3452+
def _binary_slice_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
3453+
end = stack.pop_wrapped()
3454+
start = stack.pop_wrapped()
3455+
container = stack.pop_wrapped()
3456+
3457+
def impl(container, start, end):
3458+
return container.__getitem__(slice(start, end))
3459+
3460+
res = _interpret_call(impl, container, start, end)
3461+
3462+
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
3463+
return res
3464+
3465+
return check_and_append(stack, res)
3466+
3467+
34473468
# https://docs.python.org/3.10/library/dis.html#opcode-BINARY_SUBSCR
34483469
@register_opcode_handler("BINARY_SUBSCR")
34493470
def _binary_subscr_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
@@ -4035,13 +4056,19 @@ def _end_async_for_handler_3_11(
40354056
return INTERPRETER_SIGNALS.EXCEPTION_RAISED
40364057

40374058

4038-
# https://docs.python.org/3.11/library/dis.html#opcode-END_ASYNC_FOR
4059+
# https://docs.python.org/3.11/library/dis.html#opcode-END_FOR
40394060
@register_opcode_handler("END_FOR", min_ver=(3, 12))
40404061
def _end_for_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
40414062
stack.pop_wrapped()
40424063
stack.pop_wrapped()
40434064

40444065

4066+
# https://docs.python.org/3.12/library/dis.html#opcode-END_SEND
4067+
@register_opcode_handler("END_SEND", min_ver=(3, 12))
4068+
def _end_send_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
4069+
stack.pop_wrapped()
4070+
4071+
40454072
# https://docs.python.org/3.10/library/dis.html#opcode-EXTENDED_ARG
40464073
@register_opcode_handler("EXTENDED_ARG")
40474074
def _extended_arg_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
@@ -4507,7 +4534,7 @@ def _list_to_tuple_handler(inst: dis.Instruction, /, stack: InterpreterStack, **
45074534

45084535
# https://docs.python.org/3.13/library/dis.html#opcode-LOAD_ASSERTION_ERROR
45094536
@register_opcode_handler("LOAD_ASSERTION_ERROR")
4510-
def _load_assertion_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
4537+
def _load_assertion_error_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None:
45114538
stack.append(wrap_const(AssertionError))
45124539

45134540

@@ -4613,8 +4640,11 @@ def _load_deref_handler(
46134640

46144641

46154642
# https://docs.python.org/3.10/library/dis.html#opcode-LOAD_FAST
4643+
# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_FAST_CHECK
4644+
# LOAD_FAST for Python <3.12 is LOAD_FAST_CHECK
4645+
@register_opcode_handler("LOAD_FAST_CHECK", min_ver=(3, 12))
46164646
@register_opcode_handler("LOAD_FAST")
4617-
def _load_fast_handler(
4647+
def _load_fast_check_handler(
46184648
inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs
46194649
) -> None | INTERPRETER_SIGNALS:
46204650
assert isinstance(inst.arg, int)
@@ -4637,6 +4667,34 @@ def _load_fast_handler(
46374667
return check_and_append(stack, val)
46384668

46394669

4670+
# https://docs.python.org/3.12/library/dis.html#opcode-LOAD_FAST_AND_CLEAR
4671+
@register_opcode_handler("LOAD_FAST_AND_CLEAR", min_ver=(3, 12))
4672+
def _load_fast_and_clear_handler(
4673+
inst: dis.Instruction, /, stack: InterpreterStack, co: CodeType, frame: InterpreterFrame, **kwargs
4674+
) -> None | INTERPRETER_SIGNALS:
4675+
assert isinstance(inst.arg, int)
4676+
var_num: int = inst.arg
4677+
assert var_num >= 0 and var_num < len(frame.localsplus)
4678+
4679+
val: Any = frame.localsplus[var_num]
4680+
name: str = frame.get_localsplus_name(var_num)
4681+
4682+
# empty local variable slots are initialized to Py_NULL(), in this
4683+
# case we push Py_NULL (but wrapped)
4684+
if isinstance(val, Py_NULL):
4685+
val = wrap_const(Py_NULL())
4686+
4687+
ctx: InterpreterCompileCtx = get_interpretercompilectx()
4688+
if ctx._with_provenance_tracking:
4689+
assert isinstance(val, WrappedValue), f"unexpected value of type {type(val)}, {val}, {inst}"
4690+
4691+
val = load_fast_callback(val, name)
4692+
# clear the local variable
4693+
frame.localsplus[var_num] = Py_NULL()
4694+
4695+
return check_and_append(stack, val)
4696+
4697+
46404698
# https://docs.python.org/3.10/library/dis.html#opcode-LOAD_GLOBAL
46414699
@register_opcode_handler("LOAD_GLOBAL")
46424700
def _load_global_handler(
@@ -5193,9 +5251,7 @@ def impl():
51935251

51945252
# https://docs.python.org/3.12/library/dis.html#opcode-POP_JUMP_IF_NONE
51955253
@register_opcode_handler("POP_JUMP_IF_NONE", min_ver=(3, 12))
5196-
def _pop_jump_forward_if_none_handler(
5197-
inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs
5198-
) -> int | None:
5254+
def _pop_jump_if_none_handler(inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs) -> int | None:
51995255
assert isinstance(inst.arg, int)
52005256

52015257
tos = stack.pop()
@@ -5208,7 +5264,7 @@ def _pop_jump_forward_if_none_handler(
52085264

52095265
# https://docs.python.org/3.12/library/dis.html#opcode-POP_JUMP_IF_NOT_NONE
52105266
@register_opcode_handler("POP_JUMP_IF_NOT_NONE", min_ver=(3, 12))
5211-
def _pop_jump_forward_if_none_handler(
5267+
def _pop_jump_if_not_none_handler(
52125268
inst: dis.Instruction, /, stack: InterpreterStack, inst_ptr: int, **kwargs
52135269
) -> int | None:
52145270
assert isinstance(inst.arg, int)
@@ -5391,7 +5447,7 @@ def _reraise_handler_3_11(
53915447

53925448
# https://docs.python.org/3.12/library/dis.html#opcode-RETURN_CONST
53935449
@register_opcode_handler("RETURN_CONST", min_ver=(3, 12))
5394-
def _return_value_handler(
5450+
def _return_const_handler(
53955451
inst: dis.Instruction, /, co: CodeType, stack: InterpreterStack, **kwargs
53965452
) -> int | None | INTERPRETER_SIGNALS:
53975453
assert type(inst.arg) is int
@@ -5686,6 +5742,20 @@ def impl(names_dict, name, value):
56865742
return check_signal(_interpret_call(impl, frame.names, wrap_const(name), tos))
56875743

56885744

5745+
# https://docs.python.org/3.12/library/dis.html#opcode-STORE_SLICE
5746+
@register_opcode_handler("STORE_SLICE")
5747+
def _store_slice_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
5748+
end = stack.pop_wrapped()
5749+
start = stack.pop_wrapped()
5750+
container = stack.pop_wrapped()
5751+
values = stack.pop_wrapped()
5752+
5753+
def impl(container, start, end, values):
5754+
return container.__setitem__(slice(start, end), values)
5755+
5756+
return _interpret_call_with_unwrapping(impl, container, start, end, values)
5757+
5758+
56895759
# https://docs.python.org/3.10/library/dis.html#opcode-STORE_SUBSCR
56905760
@register_opcode_handler("STORE_SUBSCR")
56915761
def _store_subscr_handler(inst: dis.Instruction, /, stack: InterpreterStack, **kwargs) -> None | INTERPRETER_SIGNALS:
@@ -5895,10 +5965,16 @@ def impl():
58955965
if res is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
58965966
runtimectx: InterpreterRuntimeCtx = get_interpreterruntimectx()
58975967
if isinstance(runtimectx.curexc, StopIteration):
5898-
stack.pop() # remove generator
5899-
stack.append(runtimectx.curexc.value)
5968+
retval = runtimectx.curexc.value
59005969
runtimectx.curexc = None
5901-
return inst_ptr + inst.arg + 1
5970+
if sys.version_info < (3, 12):
5971+
stack.pop() # remove generator
5972+
stack.append(retval)
5973+
return inst_ptr + inst.arg + 1
5974+
else:
5975+
# Python 3.12 keeps the generator, returns relative jump
5976+
stack.append(retval)
5977+
return inst.arg
59025978
else:
59035979
return res # propagate exception
59045980

thunder/tests/test_interpreter.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2119,7 +2119,13 @@ def test_list_to_tuple(jit):
21192119
def ltt():
21202120
return (*[1, 2, 3],)
21212121

2122-
assert any(i.opname == "LIST_TO_TUPLE" for i in dis.get_instructions(ltt))
2122+
if sys.version_info >= (3, 12):
2123+
assert any(
2124+
(i.opname == "CALL_INTRINSIC_1" and i.argrepr == "INTRINSIC_LIST_TO_TUPLE")
2125+
for i in dis.get_instructions(ltt)
2126+
)
2127+
else:
2128+
assert any(i.opname == "LIST_TO_TUPLE" for i in dis.get_instructions(ltt))
21232129
assert jit(ltt)() == ltt()
21242130

21252131

0 commit comments

Comments
 (0)