HF LLaVa support #1174

Open · 2 tasks
riccardofelluga opened this issue Sep 19, 2024 · 6 comments
Assignee: riccardofelluga
Labels: hf-transformers, program-coverage (Requests for model and program coverage)

@riccardofelluga (Collaborator) commented Sep 19, 2024

🚀 Model / language coverage

The idea is to support the LLaVa model from HF. This issue is mainly for tracking its status.

Blocking issues:

Minimal Repro

First, install the transformers library with pip install transformers, then run this script:

import torch
import thunder
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.to("cuda")

# Dummy inputs: a 22-token prompt and one 336x336 image (the resolution
# LLaVA 1.5 expects).
input_ids = torch.randint(1, 32000, (1, 22), device="cuda")
attention_mask = torch.ones((1, 22), dtype=torch.int64, device="cuda")
pixel_values = torch.randn((1, 3, 336, 336), device="cuda")
labels = torch.randint(-100, 32000, (1, 22), device="cuda")

# Set the BOS token and a fake <image> placeholder token (id 32000).
input_ids[0, 0] = 1
input_ids[0, 5] = 32000

model = thunder.jit(model, executors=thunder.get_default_executors())

out = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, labels=labels)
riccardofelluga added the program-coverage and hf-transformers labels on Sep 19, 2024
riccardofelluga self-assigned this on Sep 19, 2024
@t-vi (Collaborator) commented Sep 19, 2024

left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))

Note that this looks pretty bad from a "data-dependent control flow" perspective and has, indeed, been changed in transformers four months ago.
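
The problem for tracers is that not torch.sum(...) converts a tensor into a Python bool, which requires the tensor's runtime value. A minimal sketch of the failure mode (my illustration, not from the issue; assumes a recent PyTorch with torch.compile):

import torch

def f(input_ids, pad_token_id):
    # `not tensor` calls Tensor.__bool__, so the branch below depends on
    # runtime data -- a tracer cannot resolve it symbolically.
    left_padding = not torch.sum(input_ids[:, -1] == pad_token_id)
    if left_padding:
        return input_ids.flip(-1)
    return input_ids

ids = torch.randint(1, 100, (1, 8))
f(ids, 0)  # eager mode is fine

# With fullgraph=True, torch.compile should raise on the data-dependent
# bool conversion instead of silently splitting the graph.
compiled = torch.compile(f, fullgraph=True)
compiled(ids, 0)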

@riccardofelluga (Collaborator, Author) commented:

> @t-vi: Note that this looks pretty bad from a "data-dependent control flow" perspective and has, indeed, been changed in transformers four months ago.

Indeed, it does look kinda bad :(
What do you mean by "it has been changed"? The line still seems to be there in the file:

https://github.com/huggingface/transformers/blob/4d8908df272c0a9db2e5fbcc8aaed73cdf75442a/src/transformers/models/llava/modeling_llava.py#L284

@t-vi (Collaborator) commented Sep 19, 2024

Right, I'm stupid. They changed it for modeling_llava_next.py, not modeling_llava.py. :(

@riccardofelluga (Collaborator, Author) commented:
Updated the description with the relevant blocking issues.

@csarofeen (Collaborator) commented:
@kshitij12345 does the splitter correctly route these ops to the inductor path?

@kshitij12345 (Collaborator) commented Sep 30, 2024

thunderFX side-steps the data-dependent ops and works on the above snippet.

import torch
import thunder
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.to("cuda")

input_ids = torch.randint(1, 100, (1, 22), device="cuda")
attention_mask = torch.rand((1, 22), device="cuda") > 0.5
pixel_values = torch.randn((1, 3, 336, 336), device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 100, (1, 22), device="cuda")

# Set the BOS token and a fake <image> placeholder token (id 32000).
input_ids[0, 0] = 1
input_ids[0, 5] = 32000

# These paths were also tried:
# model = thunder.jit(model, executors=thunder.get_default_executors())
# model = torch.compile(model)

# thunderFX: use Thunder as a torch.compile backend.
import thunder.dynamo

backend = thunder.dynamo.ThunderCompiler(executors=thunder.get_default_executors())
model = torch.compile(model, backend=backend)

out = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, labels=labels)
print(out.loss)  # Loss is detached from the graph.

However, I see that out.loss is detached from the computation graph, so we can't call backward on it. This is because of a bug in the splitter: it doesn't correctly handle regions under torch.no_grad. I will file a separate issue for this and look into fixing it. (EDIT: issue filed at #1219)
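
For reference, a quick way to observe the symptom (a hypothetical check, not part of the original report):

# A loss attached to the autograd graph would have requires_grad=True
# and a non-None grad_fn; here both indicate detachment.
print(out.loss.requires_grad)  # False
print(out.loss.grad_fn)        # None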

Development: no branches or pull requests · 4 participants