__bool__ (and data dependent control flow) #735

Open · Tracked by #1174
t-vi opened this issue Jul 9, 2024 · 1 comment
Labels
program-coverage Requests for model and program coverage

Comments

t-vi (Collaborator) commented Jul 9, 2024

HF BERT has data-dependent control flow (in transformers' padding check):

if self.config.pad_token_id in input_ids[:, [-1, 0]]:
    warn_string = (
        "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
        "https://huggingface.co/docs/transformers/troubleshooting"
        "#incorrect-output-when-padding-tokens-arent-masked."
    )

    # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
    # attention_mask or not. In this case, we should still show a warning because this is a rare case.

input_ids is a tensor, so the `in` check ultimately calls __bool__ on a tensor, which makes us fail.
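For context, a minimal sketch of the eager semantics behind that check (assuming PyTorch's Tensor.__contains__, which is roughly an elementwise comparison followed by any()):

import torch

input_ids = torch.randint(1, 20, (1, 25))
pad_token_id = 0

# `pad_token_id in input_ids[:, [-1, 0]]` routes through Tensor.__contains__,
# roughly (element == tensor).any(); the `if` then calls bool() on the 0-dim
# result, so which branch runs depends on the tensor's runtime values.
contains = (input_ids[:, [-1, 0]] == pad_token_id).any()
if bool(contains):  # the data-dependent step a tracer cannot resolve statically
    print("padding suspected")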

Repro:

import torch, thunder, transformers

m = transformers.BertForSequenceClassification(transformers.BertConfig())
jm = thunder.jit(m)
a = torch.randint(1, 20, (1, 25))  # random token ids, batch of 1, length 25
jm(a)  # fails with NotImplementedError from the tensor proxy's __bool__
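Continuing the repro above, a possible workaround (an assumption on my part: transformers skips this padding check entirely when an attention_mask is passed), so the data-dependent branch is never reached:

mask = torch.ones_like(a)  # assumes `a` contains no padding
jm(a, attention_mask=mask)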
t-vi added the program-coverage label Jul 9, 2024
riccardofelluga (Collaborator) commented Sep 25, 2024

Same issue as #1174; it seems that the first step here would be to implement the bool ops on tensors. Some repro snippets:

import torch
import thunder

def foo(x):
    return not x  # `not` needs bool(x)

jf = thunder.jit(foo)
a = torch.tensor(0)  # 0-dim tensor, so eager bool(a) would be well defined

jf(a)  # raises NotImplementedError in the tensor proxy's __bool__

and also something like this:

def bar(x):
    return x or False  # `or` also needs bool(x) to pick a branch
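For reference, a small sketch of the eager behavior these snippets rely on (plain PyTorch, no thunder): both `not x` and `x or False` route through Tensor.__bool__, which only works for one-element tensors:

import torch

x = torch.tensor(0)

# All of these call Tensor.__bool__, which reads the single value:
print(bool(x))     # False, equivalent to x.item() != 0
print(not x)       # True
print(x or False)  # False: x is falsy, so `or` yields the right-hand operand

# On a multi-element tensor, __bool__ raises instead:
# bool(torch.tensor([0, 1]))  -> RuntimeError: Boolean value of Tensor
#                                with more than one element is ambiguous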

Stack trace for reference:

  File "/opt/pytorch/lightning-thunder/test.py", line 10, in <module>
    jf(a)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 717, in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/langctxs.py", line 136, in _fn
    result = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 219, in cache_info_wrapper
    res = fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/__init__.py", line 506, in get_computation_and_inputs
    jit_results: TraceResults = thunder_general_jit(
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 1635, in thunder_general_jit
    result = jfn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 7189, in fn_
    raise e
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 7150, in fn_2
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/test.py", line 5, in foo
    return not x
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 5976, in impl
    if bool(tos):
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1387, in impl
    return dunder_bool(x)
  File "/opt/pytorch/lightning-thunder/thunder/core/interpreter.py", line 1292, in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/jit_ext.py", line 387, in wrapper
    return fn(*args, **kwargs)
  File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 1646, in __bool__
    raise NotImplementedError
NotImplementedError
