HF Llama 1B 1 Layer slowness (inference) #1467
Comments
Instantiating a model config from a model name requires a token and approval on the HF Hub. Could you please use an explicit dictionary to instantiate the model?

from transformers.models.llama import LlamaConfig

config = LlamaConfig(
    vocab_size=128000,
    hidden_size=2048,
    intermediate_size=8192,
    num_hidden_layers=16,
    num_attention_heads=32,
    num_key_value_heads=4,
    max_position_embeddings=131072,
    rope_theta=500000.0,
)

Is this config correct?
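For completeness, here is a minimal sketch (mine, not from the comment above) of how the 1-layer variant from the issue title can be built from a config like the one above and compiled with Thunder, mirroring the repro later in the thread:

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM

# Cut the model down to a single decoder layer, as in the issue title.
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

# thunder.jit uses the nvFuser executor by default when it is available.
jmodel = thunder.jit(model)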
I updated the repro from the test (test_networks::testhf_llama).
@kevinstephano to find someone to investigate
Note that you need current (as of writing) transformers (I'm using …).
This is a comparison on DGX H100:
This is a great point. Is there a single issue we should track for updates on this?
Interesting! I wonder what's going on here? @t-vi, @tfogal, could this be our cache time?
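Not from the thread, but one rough way to check this: time the jitted call on the host without synchronizing inside the loop. As long as the GPU work per step is small enough not to back up the queue (likely for a 1-layer model), this mostly captures per-step CPU overhead (cache lookup, prologue execution, kernel launches) rather than GPU time. It assumes a thunder.jit-compiled jmodel and the args dict from the repro later in the thread:

import time
import torch

# Warm up so compilation happens outside the measurement.
for _ in range(10):
    jmodel(**args)
torch.cuda.synchronize()

iters = 100
t0 = time.perf_counter()
for _ in range(iters):
    jmodel(**args)  # enqueue work only; no sync inside the loop
t1 = time.perf_counter()
print(f"host-side time per step: {(t1 - t0) / iters * 1e6:.1f} us")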
We don't have an open issue. This latency was known, but it was previously covered up by Thunder's own latencies. It looks like Thunder's latencies have been cleaned up, except for the ~250 us startup time of each step, which is why the nvFuser latency is now prominent. The reasons it was not previously addressed were as follows:
That makes a lot of sense, and it still does! While we may not need to prioritize the work, would you please file an issue so it can be tracked and discussed directly?
Here is the nvFuser issue.
I attempted to measure latency in various scenarios and I wasn't able to measure …

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps
from collections import OrderedDict

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}

config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_positions=torch.arange(6, device="cuda"),
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

def cuda_timer(warmup_iters: int = 10, timing_iters: int = 40):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()
            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time
        return wrapper
    return decorator

@cuda_timer()
def run_model(mymodel, args):
    res = mymodel(**args)

def eager(fn):
    return fn

executors = OrderedDict()
executors['Thunder-nvFuser'] = thunder.jit
#executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile
executors['torch-eager'] = eager

#print(inspect.signature(model.forward, follow_wrapped=True))
for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100 Results:
We ran into these, too. I hope that it will be possible after #1500 is merged.
I explored the inference performance difference between … I will note that the per-operator overhead does not look like the issue, as both … DGX H100-80GB Results:
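As an aside (my sketch, not part of the comment above), per-operator overhead can also be broken down with torch.profiler, reusing exec_model and args from the benchmark script earlier in the thread:

import torch
from torch.profiler import profile, ProfilerActivity

# Profile a handful of steps and inspect where CPU time is spent per operator.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        exec_model(**args)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))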
There are a few things we can do in the nvFuser execution path to address overheads:
Adding a couple notes from our discussions today:
The items we are going to address in nvFuser:
A 1-layer HF Llama 1B is too slow in Thunder.
Repro for instantiating the model:
Timings on L40s in a studio (PyTorch 2.5.1 and NVFuser from pip nightlies).
This is after #1465; before, things were looking worse.
Currently we see 7 separate nvfusion regions:
gives
cc @apaz-cli @tfogal
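Not part of the original report, but the "7 separate nvfusion regions" above can be counted from Thunder's execution trace. A sketch, assuming the jitted model and inputs from the repro earlier in the thread, and assuming that nvFuser regions appear in the trace as bound symbols named nvFusion0, nvFusion1, and so on:

import thunder

jmodel = thunder.jit(model)
jmodel(**args)  # run once so the traces exist

# thunder.last_traces returns the chain of traces; the last one is the
# execution trace that contains the fused regions.
exec_trace = thunder.last_traces(jmodel)[-1]
nvfusions = [bsym for bsym in exec_trace.bound_symbols if bsym.sym.name.startswith("nvFusion")]
print(f"{len(nvfusions)} nvFusion regions")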