Skip to content

Commit 30e4aa1

Browse files
Implementing zip lookaside for interpreter (#1259)
Co-authored-by: Rany Kamel <rany.kamel47@gmail.com>
1 parent 617fe8e commit 30e4aa1

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

thunder/core/interpreter.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,47 @@ def impl(obj, start):
14131413
return _interpret_call(impl, obj, wrap_const(start))
14141414

14151415

1416+
def _zip_lookaside(*obj: Iterable, strict=False):
1417+
1418+
if not obj:
1419+
return
1420+
1421+
def zip(*obj, strict=False):
1422+
# zip('ABCD', 'xy') --> Ax By
1423+
sentinel = object()
1424+
iterators = [iter(it) for it in obj]
1425+
while iterators:
1426+
result = []
1427+
break_loop = False
1428+
for it in iterators:
1429+
elem = next(it, sentinel)
1430+
if elem is sentinel:
1431+
if not strict:
1432+
return
1433+
else:
1434+
break_loop = True
1435+
break
1436+
result.append(elem)
1437+
1438+
if break_loop:
1439+
break
1440+
1441+
yield tuple(result)
1442+
if result:
1443+
i = len(result)
1444+
plural = " " if i == 1 else "s 1-"
1445+
msg = f"zip() argument {i+1} is shorter than argument{plural}{i}"
1446+
raise ValueError(msg)
1447+
sentinel = object()
1448+
for i, iterator in enumerate(iterators[1:], 1):
1449+
if next(iterator, sentinel) is not sentinel:
1450+
plural = " " if i == 1 else "s 1-"
1451+
msg = f"zip() argument {i+1} is longer than argument{plural}{i}"
1452+
raise ValueError(msg)
1453+
1454+
return _interpret_call(zip, *obj, strict=wrap_const(strict))
1455+
1456+
14161457
@interpreter_needs_wrap
14171458
def eval_lookaside(
14181459
source: str | bytes | bytearray | CodeType, # A python expression
@@ -2743,6 +2784,7 @@ def _type_call_lookaside(wrapped_typ, *args, **kwargs):
27432784
any: _any_lookaside,
27442785
bool: _bool_lookaside,
27452786
enumerate: _enumerate_lookaside,
2787+
zip: _zip_lookaside,
27462788
exec: exec_lookaside,
27472789
eval: eval_lookaside,
27482790
getattr: _getattr_lookaside,

thunder/tests/test_interpreter.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,42 @@ def fn():
16911691
jit(fn)()
16921692

16931693

1694+
def test_zip_lookaside(jit):
1695+
import re
1696+
1697+
jitting = False
1698+
1699+
def foo(*a, strict=False):
1700+
return list(zip(*a, strict=strict))
1701+
1702+
jfoo = jit(foo)
1703+
jitting = False
1704+
1705+
res1 = foo([1, 2, 3], [4, 5, 6])
1706+
res2 = foo([1, 2, 3], [4, 5, 6], [7, 8, 9])
1707+
res3 = foo([1, 2], [4, 5, 6])
1708+
res4 = foo("abc", "xyz")
1709+
# , match="zip() argument 2 is longer than argument 1"
1710+
1711+
with pytest.raises(ValueError, match=re.escape("zip() argument 2 is longer than argument 1")):
1712+
res5 = foo([1, 2], [4, 5, 6], strict=True)
1713+
1714+
jitting = True
1715+
jres1 = jfoo([1, 2, 3], [4, 5, 6])
1716+
jres2 = jfoo([1, 2, 3], [4, 5, 6], [7, 8, 9])
1717+
jres3 = jfoo([1, 2], [4, 5, 6])
1718+
jres4 = jfoo("abc", "xyz")
1719+
1720+
# , match=" zip() argument 2 is longer than argument 1"
1721+
with pytest.raises(ValueError, match=re.escape("zip() argument 2 is longer than argument 1")):
1722+
jres5 = jfoo([1, 2], [4, 5, 6], strict=True)
1723+
1724+
assert res1 == jres1
1725+
assert res2 == jres2
1726+
assert res3 == jres3
1727+
assert res4 == jres4
1728+
1729+
16941730
def test_enumerate_lookaside(jit):
16951731
jitting = False
16961732

0 commit comments

Comments
 (0)