
how to provide overrides needed for model compatibility #816

Open
t-vi opened this issue Jul 22, 2024 · 2 comments
Labels: design, interpreter, program-coverage

t-vi (Collaborator) commented Jul 22, 2024

While transformers BERT is not yet fully working, we are getting closer.
However, we need to disable some data-dependent control flow (typically: input checks) to get it to work.
Transformers itself hides some, but not all, of these behind a check for whether it is being traced or compiled, e.g.

https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/modeling_attn_mask_utils.py#L256-L260

is_tracing = (
    torch.jit.is_tracing()
    or isinstance(inputs_embeds, torch.fx.Proxy)
    or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)

So here is a candidate for how we might currently be able to run BERT after fixing #805 (and subsequent bugs):

import torch
import transformers

import thunder


# Lookaside 1: skip the padding warning, which inspects the values of the input ids.
@thunder.core.jit_ext.register_general_jit_lookaside(
    transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask
)
@thunder.core.jit_ext.interpreter_needs_wrap
def dummy(*args):
    pass


# Lookaside 2: report "compiling", so transformers takes its compiler-friendly branches.
@thunder.core.jit_ext.register_general_jit_lookaside(torch._dynamo.is_compiling)
@thunder.core.jit_ext.interpreter_needs_wrap
def is_compiling():
    return True


m = transformers.BertForSequenceClassification(transformers.BertConfig())
inp = torch.randint(1, 20, (1, 32))
jm = thunder.jit(m)
jm(inp)

We might submit an issue to transformers to also check for is_tracing in transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask,
but in general I wonder how to provide such compatibility lookasides to users.

possible variants:

  • a compat executor (a rough sketch of what bundling such registrations could look like follows this list),
  • detecting the module we are tracing and adding them just for transformers (too much magic),
  • always having the torch._dynamo.is_compiling lookaside?
    ...
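
As a very rough sketch of the first variant (assuming the thunder.core.jit_ext API used above stays as-is; register_transformers_compat_lookasides is a made-up name, not an existing thunder function):

import torch
import transformers

import thunder


def register_transformers_compat_lookasides():
    # Bundle the compatibility lookasides from the snippet above so users opt in
    # explicitly, rather than thunder guessing which library is being traced.

    @thunder.core.jit_ext.register_general_jit_lookaside(
        transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask
    )
    @thunder.core.jit_ext.interpreter_needs_wrap
    def _skip_padding_warning(*args, **kwargs):
        # The original method inspects the input ids; skip it under thunder.
        pass

    @thunder.core.jit_ext.register_general_jit_lookaside(torch._dynamo.is_compiling)
    @thunder.core.jit_ext.interpreter_needs_wrap
    def _is_compiling():
        # Pretend to be a compiler so transformers skips data-dependent checks.
        return True

A compat executor would effectively perform the same registrations when it is enabled; the open question is where they should live.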
@t-vi added the interpreter, program-coverage, and design labels Jul 22, 2024
IvanYashchuk (Collaborator) commented:

PyTorch is working on providing a public interface for queries like "Is Dynamo active now?", but for now it just returns False:
https://github.com/pytorch/pytorch/blob/5c78581fc91aa50673790a7f591294a19b489e20/torch/compiler/__init__.py#L234
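
If/when that interface settles, a single built-in lookaside on the public query might cover most libraries. A minimal sketch, assuming the query is exposed as torch.compiler.is_compiling (the exact name at the linked line may differ):

import torch

import thunder


@thunder.core.jit_ext.register_general_jit_lookaside(torch.compiler.is_compiling)
@thunder.core.jit_ext.interpreter_needs_wrap
def _is_compiling_lookaside():
    # Report "compiling" so libraries take their tracing-friendly code paths.
    return True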

@tfogal self-assigned this Jul 22, 2024
tfogal (Collaborator) commented Jul 22, 2024

triage review:

  • we need some sort of magic to support transformers because it inspects input data unless it is running under Dynamo
  • do we ask the user to do something before compiling the transformers model, or pass it as a thunder.jit argument?
  • or do we try to detect that we're in transformers and do something?
  • maybe we just have a lookaside for "is dynamo active"?
  • dynamo devs are creating a more public interface for querying whether or not we're in that compiler
  • would be great to participate in their interface design here
  • let's reach out to Meta, @tfogal to figure out a person here
