Skip to content

Commit 18899a8

Browse files
authored
Allowing static constraint in torch/__init__.py (#613)
1 parent f334d39 commit 18899a8

File tree

5 files changed

+38
-0
lines changed

5 files changed

+38
-0
lines changed

thunder/core/prims.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ class PrimIDs(Enum):
261261
# Memory access methods
262262
ITEM = auto()
263263
COPY_ = auto()
264+
#
265+
SINK = auto()
264266

265267

266268
class OpTags(Enum):
@@ -3888,3 +3890,11 @@ def copy__meta(
38883890

38893891

38903892
copy_ = make_prim(PrimIDs.COPY_, "copy_", meta=copy__meta, tags=(OpTags.DONT_DCE,))
3893+
3894+
3895+
def sink_meta(*args, **kwargs):
3896+
return
3897+
3898+
3899+
# TODO do we want another tag to remove this after prologue is constructed?
3900+
sink = make_prim(PrimIDs.SINK, "sink", meta=sink_meta, tags=(OpTags.DONT_DCE,))

thunder/core/proxies.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,7 @@ def known_value(self) -> bool:
612612

613613
def make_static_constrained(self):
614614
baseutils.check(self.constraint != CONSTRAINT.DYNAMIC, lambda: f"dynamic NumberProxy cannot be made static")
615+
baseutils.check(self.value is not None, lambda: f"static NumberProxy needs to have value")
615616
self.constraint = CONSTRAINT.STATIC
616617

617618
def make_constrainable(self):

thunder/executors/pythonex.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,5 +365,13 @@ def _elementwise_binary_checker(a: NumberLike | TensorProxy, b: NumberLike | Ten
365365
ex.register_implementation(prims.sub, sub, checker=_elementwise_binary_checker)
366366
ex.register_implementation(prims.div, div, checker=_elementwise_binary_checker)
367367

368+
369+
def _sink(*args, **kwargs):
370+
return
371+
372+
373+
sink = ex.register_operator("sink", like=prims.sink, fn=_sink)
374+
ex.register_implementation(prims.sink, sink, checker=_always_executable)
375+
368376
# TODO: Restore truediv once we find it...
369377
# ex.register_implementation(prims.truediv, truediv, checker=_elementwise_binary_checker)

thunder/tests/test_jit_general.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,18 @@ def bar(t):
955955
jbar = thunder.jit(bar, cache="symbolic values")
956956
t = torch.randn(4, device="cpu")
957957
jbar(t)
958+
959+
960+
def test_cache_symbolic_values_torch_device():
961+
if not torch.cuda.is_available():
962+
pytest.skip("CUDA not available")
963+
964+
def foo(dev, idx):
965+
# NOTE dtype needs to be explicit, see issue: https://github.com/Lightning-AI/lightning-thunder/issues/621
966+
return torch.ones(1, device=torch.device(dev, idx), dtype=torch.float32)
967+
968+
jfoo = thunder.jit(foo, cache="symbolic values")
969+
expected = foo("cuda", 0)
970+
actual = jfoo("cuda", 0)
971+
972+
assert_close(expected, actual)

thunder/torch/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4906,6 +4906,10 @@ def torch_device(device_or_str: DeviceLike, /, index: int | None = None) -> devi
49064906
not has_device_idx,
49074907
lambda: f"device string must not include an index because index was passed explicitly: {device_or_str}",
49084908
)
4909+
if isinstance(index, NumberProxy):
4910+
index.make_static_constrained()
4911+
prims.sink(index)
4912+
index = index.value
49094913

49104914
return devices.Device(device_or_str, index)
49114915

0 commit comments

Comments
 (0)