diff --git a/thunder/tests/test_transforms.py b/thunder/tests/test_transforms.py
index 1f7ea58562..2418eeae6a 100644
--- a/thunder/tests/test_transforms.py
+++ b/thunder/tests/test_transforms.py
@@ -624,3 +624,24 @@ def fn():
     assert_close(jfn(), fn())
 
     assert any(("CUDAGraph" in bsym.sym.name) for bsym in thunder.last_traces(jfn)[-1].bound_symbols)
+
+
+def test_disable_params_and_buffer_check():
+    from thunder.tests.litgpt_model import Config
+    from litgpt.model import GPT
+    from thunder.transforms.extraction_only_prologue_transform import ExtractionOnlyPrologueTransform
+
+    model = GPT(Config.from_name("llama1-like", n_layer=1))
+    x = torch.randint(model.max_seq_length, (2, 5))
+    cmodel = thunder.jit(model, transforms=[ExtractionOnlyPrologueTransform()])
+    _ = cmodel(x)
+    prologue_trc = thunder.last_prologue_traces(cmodel)[-1]
+
+    check_bsyms = tuple(
+        filter(
+            lambda bsym: bsym.sym.id == thunder.executors.pythonex.check_tensor_shape_and_metadata.id,
+            prologue_trc.bound_symbols,
+        )
+    )
+
+    assert len(check_bsyms) == 1  # Only the check for the input `x` is expected to remain.
diff --git a/thunder/transforms/extraction_only_prologue_transform.py b/thunder/transforms/extraction_only_prologue_transform.py
new file mode 100644
index 0000000000..7f335ef33c
--- /dev/null
+++ b/thunder/transforms/extraction_only_prologue_transform.py
@@ -0,0 +1,26 @@
+import thunder
+from thunder.core.trace import from_trace
+from thunder.core.proxies import ProxyTag
+
+
+class ExtractionOnlyPrologueTransform(thunder.Transform):
+    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
+        new_prologue_trace = from_trace(prologue_trace)
+        new_bsyms = []
+
+        for bsym in prologue_trace.bound_symbols:
+            # NOTE: Skip the shape/metadata check for any TensorProxy tagged with
+            # `STATIC_MEMORY_LOCATION`. We assume such proxies are Parameters or
+            # Buffers, so it should be safe to disable the check for them.
+            if (
+                bsym.sym.id == thunder.prims.PrimIDs.CHECK_TENSOR_SHAPE_AND_METADATA
+                and ProxyTag.STATIC_MEMORY_LOCATION in bsym.args[0].tags
+            ):
+                continue
+
+            new_bsyms.append(bsym)
+
+        new_prologue_trace.bound_symbols = new_bsyms
+
+        new_prologue_trace.set_provenance("Extraction only prologue pass")
+        return new_prologue_trace, computation_trace, epilogue_trace