Skip to content

Commit 3f68b34

Browse files
kshitij12345riccardofelluga
authored andcommitted
Add a transform to remove param and buffer shape check from prologue (#1564)
1 parent e8ae8c6 commit 3f68b34

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

thunder/tests/test_transforms.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,3 +624,24 @@ def fn():
624624
assert_close(jfn(), fn())
625625

626626
assert any(("CUDAGraph" in bsym.sym.name) for bsym in thunder.last_traces(jfn)[-1].bound_symbols)
627+
628+
629+
def test_disable_params_and_buffer_check():
630+
from thunder.tests.litgpt_model import Config
631+
from litgpt.model import GPT
632+
from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
633+
634+
model = GPT(Config.from_name("llama1-like", n_layer=1))
635+
x = torch.randint(model.max_seq_length, (2, 5))
636+
cmodel = thunder.jit(model, transforms=[ExtractionOnlyPrologueTransform()])
637+
_ = cmodel(x)
638+
prologue_trc = thunder.last_prologue_traces(cmodel)[-1]
639+
640+
check_bsyms = tuple(
641+
filter(
642+
lambda bsym: bsym.sym.id == thunder.executors.pythonex.check_tensor_shape_and_metadata.id,
643+
prologue_trc.bound_symbols,
644+
)
645+
)
646+
647+
assert len(check_bsyms) == 1 # We only have the check for input.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import thunder
2+
from thunder.core.trace import from_trace
3+
from thunder.core.proxies import ProxyTag
4+
5+
6+
class ExtractionOnlyPrologueTransform(thunder.Transform):
7+
def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
8+
new_prologue_trace = from_trace(prologue_trace)
9+
new_bsyms = []
10+
11+
for bsym in prologue_trace.bound_symbols:
12+
# NOTE - We assume TensorProxy's tagged with `STATIC_MEMORY_LOCATION` to
13+
# be Parameters or Buffer. It should be safe to disable check for
14+
# tensors we deem to be static.
15+
if (
16+
bsym.sym.id == thunder.prims.PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA
17+
and ProxyTag.STATIC_MEMORY_LOCATION in bsym.args[0].tags
18+
):
19+
continue
20+
21+
new_bsyms.append(bsym)
22+
23+
new_prologue_trace.bound_symbols = new_bsyms
24+
25+
new_prologue_trace.set_provenance("Extraction only prologue pass")
26+
return new_prologue_trace, computation_trace, epilogue_trace

0 commit comments

Comments
 (0)