AssertionError from process_recorded_modifications #1405

Open
IvanYashchuk opened this issue Nov 7, 2024 · 4 comments
Labels: huggingface (For supporting HF models), jit

Comments

IvanYashchuk (Collaborator) commented Nov 7, 2024

🐛 Bug

I get the following assertion error from Thunder JIT:

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1731, in thunder_general_jit(fn, args, kwargs, record_history, sharp_edges, ad_hoc_executor)
   1729         prims.python_return(result)
   1730         computation_trace.set_current_source_location(None, None)
-> 1731         process_recorded_modifications(ctx, epilogue_trace)
   1732         last_interpreter_log = jfn._last_interpreter_log
   1734 pro_to_comp, computation_intermediates = get_computation_inputs_and_intermediates(computation_trace)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1610, in process_recorded_modifications(ctx, epilogue_trace)
   1608 if inst == PseudoInst.STORE_SUBSCR:
   1609     (value,) = args
-> 1610     assert isinstance(value.value, Proxy)
   1612     assert modified_object.provenance.inst is PseudoInst.LOAD_ATTR
   1613     assert modified_object.provenance.inputs[1].inst is PseudoInst.CONSTANT

AssertionError:

using the following script:

import torch
import thunder
from transformers import Qwen2Config, Qwen2ForCausalLM

# https://huggingface.co/Qwen/Qwen2.5-7B-Instruct/blob/main/config.json
configuration = Qwen2Config(
    # Qwen2.5-7B-Instruct uses Grouped-Query Attention, while the default
    # config uses Multi-Head Attention
    num_attention_heads=28,
    num_key_value_heads=4,
    # Scaled down for testing
    hidden_size=56,
    vocab_size=4096,
)
configuration.num_hidden_layers = 1
with torch.device("cuda"):
    model = Qwen2ForCausalLM(configuration).to(torch.bfloat16)

compiled_model = thunder.jit(model)
input_ids = torch.randint(0, configuration.vocab_size, (1, configuration.max_position_embeddings), device="cuda")
compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)

The transformers version is 4.45.2.

IvanYashchuk added the jit and huggingface (For supporting HF models) labels on Nov 7, 2024
t-vi (Collaborator) commented Nov 11, 2024

Thank you @IvanYashchuk

t-vi self-assigned this on Nov 11, 2024
t-vi (Collaborator) commented Nov 11, 2024

@IvanYashchuk what transformers version are you using? I'm getting an indexing error with 4.43-ish and a different assertion error (which needs fixing, too) with 4.46.2.

t-vi (Collaborator) commented Nov 11, 2024

So with 4.46.2 and the following lookaside, things seem to work:

import thunder
from transformers.modeling_utils import PreTrainedModel

# Lookaside: when the interpreter hits the loss_function property getter,
# call it directly instead of tracing into its body.
@thunder.core.jit_ext.register_general_jit_lookaside(PreTrainedModel.loss_function.fget)
@thunder.core.jit_ext.interpreter_needs_wrap
def fn(*args, **kwargs):
    return PreTrainedModel.loss_function.fget(*args, **kwargs)

The loss_function property uses the Python re module to parse the loss-function config, so I wonder if we should allow marking modules as "treat everything here as opaque". It's a dangerous tool, but I think it might be more reasonable than relying on the internals of transformers.
@lantiga for UX thoughts
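
As a rough sketch of what such an opt-in "opaque module" marking could look like, built only from the lookaside API used in the snippet above (the helper name and the loop over module members are hypothetical, not existing thunder API): register an opaque lookaside for every function in a module so the interpreter calls the original instead of tracing into it.

import inspect
import re
import thunder

def mark_module_opaque(module):
    # Hypothetical helper, not part of thunder: register a lookaside for every
    # public function in `module` so the interpreter calls the original
    # function instead of tracing through its body.
    for _, fn in inspect.getmembers(module, inspect.isfunction):
        def opaque(*args, __fn=fn, **kwargs):
            return __fn(*args, **kwargs)
        thunder.core.jit_ext.register_general_jit_lookaside(fn)(
            thunder.core.jit_ext.interpreter_needs_wrap(opaque)
        )

# e.g. treat everything in the `re` module as opaque
mark_module_opaque(re)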

IvanYashchuk (Collaborator, Author) commented
I get this error with transformers version 4.45.2.
