
Cannot infer concrete type of torch.nn.Module #36370

Open · Jereshea opened this issue Feb 24, 2025 · 2 comments

Jereshea commented Feb 24, 2025

Environment Info:

  • Python version: 3.12.2
  • PyTorch version: 2.5.1+cpu
  • Transformers version: 4.49.0
  • Model name: "meta-llama/Llama-3.1-8B-Instruct"

I am trying to run inference with "meta-llama/Llama-3.1-8B-Instruct" using torch.jit.trace(). Below is my code snippet:

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # Trace and freeze the model's forward pass
    traced_model = torch.jit.trace(model, input_ids, check_trace=False, strict=False)
    traced_model = torch.jit.freeze(traced_model)
    # Attempt to generate with the traced module
    output = traced_model.generate(input_ids)

However, I am encountering errors when I call generate() on the traced model. Is there any way to work around this error?

RuntimeError: Tracer cannot infer type of (tensor([[[-3.1250,  1.3750,  7.2500,  ...,  1.7422,  1.7422,  1.7422],
         [ 5.8438,  7.0625,  5.8438,  ..., -4.0938, -4.0938, -4.0938],
         [ 1.1875,  0.6328,  1.5156,  ..., -8.6250, -8.6250, -8.6250],
         ...,
         [ 5.1250,  1.3047, -1.6094,  ..., -2.7500, -2.7500, -2.7500],
         [ 5.5625,  5.2188,  0.9844,  ..., -3.6094, -3.6094, -3.6094],
         [ 1.0312, -0.7852, -1.1406,  ..., -3.6562, -3.6562, -3.6562]]],
       dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), DynamicCache())
:Cannot infer concrete type of torch.nn.Module

But when I try a different version of Transformers (4.43.2), I get a different error:

AttributeError: 'RecursiveScriptModule' object has no attribute 'generate'

Is there a fix for this issue in any version of Transformers, or an alternative approach that enables generate() with a traced model?
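
For context on the first error: the traceback shows the forward pass returning (logits, DynamicCache()), and it is the DynamicCache object whose type the tracer cannot infer. A rough workaround I am experimenting with (only a sketch, not a verified solution) is to disable the cache and dict outputs so the forward returns plain tensors, trace just the forward pass, and drive a greedy decoding loop in eager Python. Note that torch.jit.trace specializes on the example input's shape, so the traced graph is not guaranteed to behave correctly across sequence lengths:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    model.config.use_cache = False    # the DynamicCache output is what the tracer rejects
    model.config.return_dict = False  # return plain tuples of tensors, not ModelOutput
    model.eval()

    input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
    with torch.no_grad():
        traced = torch.jit.trace(model, input_ids, check_trace=False, strict=False)
        traced = torch.jit.freeze(traced)

    # Greedy decoding driven from eager Python; only the forward pass is traced.
    generated = input_ids
    with torch.no_grad():
        for _ in range(32):  # max_new_tokens
            logits = traced(generated)[0]  # first element of the output tuple
            next_token = logits[:, -1, :].argmax(-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    print(tokenizer.decode(generated[0], skip_special_tokens=True))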

@Rocketknight1 (Member) commented

I don't think generate() is intended to work automatically with torch.jit.trace(), is it? cc @gante - do we have a recommended way to trace/export generation loops?

@Jereshea (Author) commented

Is there any recommended alternative solution?

In the meantime, I performed inference using the IPEX backend. Below is the snippet I used:

    import intel_extension_for_pytorch as ipex

    # Optimize the model for LLM inference with the IPEX backend
    model = ipex.llm.optimize(model, dtype=amp_dtype, inplace=True, deployment_mode=True)
    output = model.generate(inputs)

When deployment_mode=True is set, IPEX internally uses torch.jit.trace (Reference), and IPEX includes modifications that allow the traced model to be used with generate() (Reference).

Do we have any similar fix or approach like this in native PyTorch?
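
One native-PyTorch option I am aware of, if ahead-of-time tracing is not a hard requirement, is torch.compile: compiling only the forward method leaves the module object (and therefore generate()) intact, so the generation loop still runs in Python while the forward pass is optimized. A minimal sketch, not equivalent to a serialized TorchScript artifact:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-3.1-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    model.eval()

    # Compile only the forward pass; the module (and its generate()) is untouched.
    model.forward = torch.compile(model.forward)

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))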
