HF LLaVa support #1174
Comments
Note that this looks pretty bad from a "data-dependent control flow" perspective and has, indeed, been changed in transformers four months ago.
Indeed it does look kinda bad :(
Right, I'm stupid. They changed it for modelling_llava_next.py, not modelling_llava.py. :(
Updated the description with the relevant blocking issues.
@kshitij12345 does the splitter correctly route these ops to the inductor path?
thunderFX side-steps the data-dependent ops and works on the snippet below:

```python
import torch
import thunder
import thunder.dynamo
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.to("cuda")

input_ids = torch.randint(1, 100, (1, 22), device="cuda")
attention_mask = torch.rand((1, 22), device="cuda") > 0.5
pixel_values = torch.randn((1, 3, 336, 336), device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 100, (1, 22), device="cuda")

# Set up a fake image token id
input_ids[0, 0] = 1
input_ids[0, 5] = 32000

# model = thunder.jit(model, executors=thunder.get_default_executors())
# model = torch.compile(model)
backend = thunder.dynamo.ThunderCompiler(executors=thunder.get_default_executors())
model = torch.compile(model, backend=backend)

out = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, labels=labels)
print(out.loss)  # Loss is detached from the graph. However, I see that
```
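For context on the "data-dependent ops" mentioned above: the single-argument form of `torch.where` is data-dependent because the shape of its output depends on the tensor's *values*, which cannot be known at trace time. A minimal illustration with toy values (unrelated to the model):

```python
import torch

mask = torch.tensor([True, False, True])

# torch.where(condition) returns a tuple of index tensors, one per dimension.
(idx,) = torch.where(mask)

# idx is tensor([0, 2]): its shape (2,) equals the number of True entries
# in mask, so a tracer cannot determine it without looking at the data.
print(idx)
```

This is the pattern that forces a graph break, which a splitter can route to a fallback path instead of tracing through it.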
🚀 Model / language coverage
The idea is to support the LLaVa model from HF. This issue is mainly for tracking the status.
Blocking issues:
- torch.where(condition) with thunder.jit #124

Minimal Repro
First of all, get the transformers library with `pip install transformers`, then run this script:
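As background on the torch.compile backend mechanism that ThunderCompiler plugs into in the thread above: a backend is just a callable that receives each captured FX subgraph (plus example inputs) and returns a callable to run in its place. A minimal illustrative backend, not Thunder's implementation:

```python
import torch

def inspecting_backend(gm, example_inputs):
    # torch.compile hands each captured FX GraphModule to the backend;
    # whatever callable the backend returns replaces that subgraph.
    # Data-dependent ops force graph breaks, so the backend may be
    # invoked once per captured subgraph rather than once per model.
    print(f"captured a graph with {len(list(gm.graph.nodes))} nodes")
    return gm.forward  # fall back to running the captured graph eagerly

@torch.compile(backend=inspecting_backend)
def f(x):
    return torch.relu(x) + 1

print(f(torch.tensor([-1.0, 2.0])))  # tensor([1., 3.])
```

A splitter-style backend like thunder.dynamo.ThunderCompiler does the same thing, except it decides per subgraph whether to hand the graph to Thunder or to a fallback executor.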