HF Llama 1B 1 Layer slowness (inference) #1467

Open
t-vi opened this issue Nov 25, 2024 · 15 comments

@t-vi
Collaborator

t-vi commented Nov 25, 2024

A 1-layer HF Llama 1B model is too slow in Thunder.

Repro for instantiating the model:

import torch
from transformers.models.llama import LlamaForCausalLM, LlamaConfig

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}


config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_position=torch.arange(6, device="cuda"),  # HF keyword is cache_position (singular)
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform
# Variants also tried (commented out):
#   executors=('apex', 'cudnn', 'sdpa', 'torchcompile_cat', 'nvfuser')
#   transforms=(CUDAGraphTransform(),)
jm = thunder.jit(model)

res = jm(**args)

Timings on an L40S in a Studio (PyTorch 2.5.1 and nvFuser from pip nightlies):

%timeit jm(**args); torch.cuda.synchronize()
2.19 ms ± 4.41 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit model(**args); torch.cuda.synchronize()
2.05 ms ± 3.11 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
cm = torch.compile(model)
cm(**args);
%timeit cm(**args); torch.cuda.synchronize()
1.36 ms ± 9.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

This is after #1465; before that, things looked worse.

Currently we see 7 separate nvFusion regions:

import collections
collections.Counter(bsym.sym.name for bsym in thunder.last_traces(jm)[-1].bound_symbols)

gives

Counter({'python_del': 18,
         'unpack_trivial': 15,
         'linear': 8,
         'nvFusion0': 1,
         'embedding': 1,
         'triu': 1,
         'matmul': 1,
         'TorchCompile0': 1,
         'clone': 1,
         'nvFusion1': 1,
         'copy_with_setitem_impl': 1,
         'TorchCompile1': 1,
         'nvFusion2': 1,
         'sdpaex_grad_forward_scaled_dot_product_efficient_attention': 1,
         'nvFusion3': 1,
         'nvFusion4': 1,
         'nvFusion5': 1,
         'nvFusion6': 1,
         'python_return': 1})
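To read the Counter above at a glance, the sketch below groups the symbol names by fusion executor and counts the remaining calls, each of which presumably carries some per-call overhead. It works on a copy of the printed counts rather than a live trace. The naming pattern for fusion regions and the set of "bookkeeping" symbols are assumptions inferred from the names in this output, and the helper names are hypothetical.

```python
import collections
import re

# Counts copied from the Counter output above (no live Thunder trace needed).
symbol_counts = collections.Counter({
    "python_del": 18, "unpack_trivial": 15, "linear": 8, "nvFusion0": 1,
    "embedding": 1, "triu": 1, "matmul": 1, "TorchCompile0": 1, "clone": 1,
    "nvFusion1": 1, "copy_with_setitem_impl": 1, "TorchCompile1": 1,
    "nvFusion2": 1,
    "sdpaex_grad_forward_scaled_dot_product_efficient_attention": 1,
    "nvFusion3": 1, "nvFusion4": 1, "nvFusion5": 1, "nvFusion6": 1,
    "python_return": 1,
})

# Assumption: fusion regions are named <executor prefix><index>, e.g. "nvFusion3".
FUSION_NAME = re.compile(r"^(nvFusion|TorchCompile)\d+$")
# Assumption: these symbols are trace bookkeeping rather than executed ops.
BOOKKEEPING = {"python_del", "unpack_trivial", "python_return"}


def count_regions(counts):
    """Number of fusion regions per executor prefix."""
    regions = collections.Counter()
    for name, n in counts.items():
        match = FUSION_NAME.match(name)
        if match:
            regions[match.group(1)] += n
    return dict(regions)


def count_calls(counts):
    """Number of executed calls (fusions plus unfused ops), ignoring bookkeeping."""
    return sum(n for name, n in counts.items() if name not in BOOKKEEPING)


print(count_regions(symbol_counts), count_calls(symbol_counts))
```

On this output that gives 7 nvFuser regions and 2 torch.compile regions among 23 executed calls, so the linears, matmul, embedding, and triu still run as separate calls outside any fusion.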

cc @apaz-cli @tfogal

@IvanYashchuk
Collaborator

IvanYashchuk commented Nov 25, 2024

Instantiating a model config from a model name requires a token and approval on the HF Hub. Could you please use an explicit dictionary to instantiate the model?
I tried following https://github.com/Lightning-AI/litgpt/blob/22528bf29c116fb70b210a0c574507d2c62f6619/litgpt/config.py#L662-L681
and came up with

from transformers.models.llama import LlamaConfig

config = LlamaConfig(
    vocab_size=128000,
    hidden_size=2048,
    intermediate_size=8192,
    num_hidden_layers=16,
    num_attention_heads=32,
    num_key_value_heads=4,
    max_position_embeddings=131072,
    rope_theta=500000.0,
)

Is this config correct?

@t-vi
Collaborator Author

t-vi commented Nov 25, 2024

I updated the repro above with the config from the test (test_networks::testhf_llama).

@mruberry
Collaborator

@kevinstephano to find someone to investigate

@t-vi
Collaborator Author

t-vi commented Nov 25, 2024

Note that you need a current (as of writing) transformers release (I'm using 4.46.3, which works much better than 4.45.x).

@kevinstephano
Collaborator

This is a comparison on DGX H100:

| Execution Type | Wall Clock Time (ms) | CPU Overhead (ms) | Kernel Time (ms) | Kernels | Overhead / Kernel (us) |
|---|---|---|---|---|---|
| Thunder-nvFuser | 1.281 | 0.984 | 0.297 | 32 | 30.7 |
| Thunder-torch | 1.232 | 0.814 | 0.413 | 77 | 10.5 |
| torch.compile | 0.455 | 0.163 | 0.292 | 24 | 6.8 |
| torch-eager | 1.014 | 0.616 | 0.398 | 65 | 9.4 |
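The last column appears to be CPU overhead divided by the number of kernels, and wall-clock time appears to be roughly CPU overhead plus kernel time. That reading of the columns is an inference from the values, not something stated in the comment; the sketch below recomputes both from the table's own numbers, and they agree to within rounding (about 0.1 us per kernel).

```python
# Rows copied from the DGX H100 table above:
# name: (wall clock ms, CPU overhead ms, kernel time ms, kernel count)
ROWS = {
    "Thunder-nvFuser": (1.281, 0.984, 0.297, 32),
    "Thunder-torch": (1.232, 0.814, 0.413, 77),
    "torch.compile": (0.455, 0.163, 0.292, 24),
    "torch-eager": (1.014, 0.616, 0.398, 65),
}


def overhead_per_kernel_us(cpu_overhead_ms: float, kernels: int) -> float:
    """Average CPU overhead per launched kernel, in microseconds."""
    return cpu_overhead_ms * 1000.0 / kernels


for name, (wall_ms, cpu_ms, kernel_ms, kernels) in ROWS.items():
    print(f"{name:16s} {overhead_per_kernel_us(cpu_ms, kernels):5.1f} us/kernel, "
          f"CPU + kernel = {cpu_ms + kernel_ms:.3f} ms vs wall = {wall_ms:.3f} ms")
```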

@kevinstephano
Collaborator

The nvFuser issue is that, with or without segmentation, nvFuser currently has an overhead of around 20 to 30 us per kernel, far more than the ~1 us runtime of most kernels in the repro example. To fix this, nvFuser would need to re-architect its hot path to return a sequence of kernels and launch parameters instead of going through its cache hierarchy, which is currently too expensive for inference.
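As a back-of-the-envelope illustration of why 20 to 30 us of per-kernel overhead swamps ~1 us kernels, the sketch below uses a toy pipeline model. This model is an assumption, not how nvFuser or the CUDA driver actually schedule work: while the GPU runs one kernel, the host spends `overhead_us` preparing the next launch, so each kernel effectively costs max(overhead_us, kernel_us). The kernel count of 32 is taken from the Thunder-nvFuser row of the table above; the function name is hypothetical.

```python
def launch_bound_estimate(num_kernels: int, overhead_us: float, kernel_us: float):
    """Toy pipeline model (assumption): while the GPU runs one kernel, the host
    spends overhead_us preparing the next launch, so each kernel effectively
    costs max(overhead_us, kernel_us). Returns (step_us, gpu_busy_fraction)."""
    per_kernel_us = max(overhead_us, kernel_us)
    step_us = num_kernels * per_kernel_us
    gpu_busy_fraction = kernel_us / per_kernel_us
    return step_us, gpu_busy_fraction


for overhead_us in (20.0, 30.0):
    step_us, busy = launch_bound_estimate(num_kernels=32, overhead_us=overhead_us, kernel_us=1.0)
    print(f"{overhead_us:.0f} us overhead/kernel -> ~{step_us:.0f} us per step, GPU busy {busy:.1%}")
```

Under this model, once overhead_us exceeds kernel_us the step time is set entirely by host overhead, which matches the point above that nvFuser's hot path, not its kernels, is the bottleneck for inference.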

Another thing noticeable in Thunder is ~250 us of startup time per step that does not occur in torch.compile or torch-eager. This startup time prevents the remaining CPU overhead from overlapping with the large GEMM at the end of the previous step. It might account for the difference in CPU overhead between Thunder-torch and torch-eager, although the two also differ in the number of kernels.

[Screenshot, 2024-11-30]

@mruberry
Collaborator

mruberry commented Dec 2, 2024

> The nvFuser issue is that no matter whether there is segmentation or not, nvFuser has an overhead, currently, of around 20 to 30 us per kernel which is way more than the 1 us time of most kernels in the repro example. nvFuser would need to re-architect its hot path to return a sequence of kernels and launch parameters instead of going through its cache hierarchy, currently, that is too expensive for inference.

This is a great point. Is there a single issue we should track for updates on this?

> Another thing that is noticeable in Thunder is that there is ~250 us of startup time that is not occurring in torch.compile or torch-eager that does not allow the remaining CPU overhead to overlap with the large GEMM at the end of the previous step. This difference might be the difference in CPU overhead between Thunder-torch and torch.eager, although, there is a difference in the number of kernels.

Interesting! I wonder what's going on here. @t-vi, @tfogal, could this be our cache time?

@kevinstephano
Collaborator

> This is a great point. Is there a single issue we should track for updates on this?

We don't have an open issue. This latency was known, but it was previously masked by Thunder's own latencies. It looks like Thunder's latencies have been cleaned up, except for the ~250 us startup time of each step, which is why the nvFuser latency is now prominent. The reasons it was not addressed previously were:

  1. Thunder's own latency was so dominant.
  2. We are focused on training, not inference.
  3. The strategy was to use CUDA Graphs for inference.

@mruberry
Collaborator

mruberry commented Dec 2, 2024

That makes a lot of sense, and it still does! While we may not need to prioritize the work, would you please file an issue so it can be tracked and discussed directly?

@kevinstephano
Collaborator

> That makes a lot of sense, and it still does! While we may not need to prioritize the work, would you please file an issue so it can be tracked and discussed directly?

Here is the nvFuser issue.

t-vi changed the title from "HF Llama 1B 1 Layer slowness" to "HF Llama 1B 1 Layer slowness (inference)" on Dec 2, 2024
@kevinstephano
Collaborator

I attempted to measure latency in various scenarios, but I wasn't able to measure Thunder+torch.compile. Is anyone else able to do so?

import torch
import thunder
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from typing import Callable
from functools import partial, wraps
from collections import OrderedDict

LLAMA_3_2_1B_CFG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 64,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 8192,
    "max_position_embeddings": 131072,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 16,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": {
        "factor": 32.0,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3",
    },
    "rope_theta": 500000.0,
    "tie_word_embeddings": True,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.45.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
    "_commit_hash": "4e20de362430cd3b72f300e6b0f18e50e7166e08",
}


config = LlamaConfig(**LLAMA_3_2_1B_CFG)
config.num_hidden_layers = 1

with torch.device("cuda"):
    model = LlamaForCausalLM(config).to(torch.bfloat16).requires_grad_(False).eval()

args = dict(
    cache_position=torch.arange(6, device="cuda"),  # HF keyword is cache_position (singular)
    input_ids=torch.tensor([[128000, 791, 1401, 311, 2324, 374]], device="cuda"),
    attention_mask=torch.ones(1, 6, dtype=torch.int64, device="cuda"),
    inputs_embeds=None,
    use_cache=True,
    return_dict=True,
)

def cuda_timer(warmup_iters: int = 10, timing_iters: int = 40):
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> float:
            for _ in range(warmup_iters):
                fn(*args, **kwargs)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            torch.cuda.synchronize()
            start.record()
            for _ in range(timing_iters):
                fn(*args, **kwargs)
            end.record()
            torch.cuda.synchronize()

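            # Average elapsed time per iteration in ms, measured between two CUDA events on
            # the current stream; it includes GPU idle gaps caused by CPU overhead.
            kernel_time = start.elapsed_time(end) / timing_iters
            return kernel_time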
        return wrapper
    return decorator

@cuda_timer()
def run_model(mymodel, args):
    mymodel(**args)

def eager(fn):
    return fn

executors = OrderedDict()
executors['Thunder-nvFuser'] = thunder.jit
#executors['Thunder-torch.compile'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa', 'torchcompile'])
executors['Thunder-torch'] = partial(thunder.jit, executors=['apex', 'cudnn', 'sdpa'])
executors['torch.compile'] = torch.compile
executors['torch-eager'] = eager

#print(inspect.signature(model.forward, follow_wrapped=True))

for name, func in executors.items():
    exec_model = func(model)
    kernel_time = run_model(exec_model, args)
    print(f"{name} {kernel_time:.03f} ms")

DGX H100 Results:

| Execution Type | Wall Clock Time (ms) |
|---|---|
| Thunder-nvFuser | 1.082 |
| Thunder-torch.compile | N/A |
| Thunder-torch | 0.965 |
| torch.compile | 0.313 |
| torch-eager | 0.834 |
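The harness above times with CUDA events, so it captures the end-to-end time of the enqueued work, including GPU idle gaps caused by CPU overhead. To look at the host side separately, a complementary sketch (an assumption about methodology, not what produced the numbers above) times each call with `time.perf_counter`. Because CUDA kernel launches are asynchronous, timing a call without `torch.cuda.synchronize()` roughly approximates host-side overhead, provided the model has no internal synchronization points; putting a synchronize inside `fn` gives the full step time instead. `host_timer` is a hypothetical helper name.

```python
import statistics
import time
from typing import Callable


def host_timer(fn: Callable[[], object], warmup_iters: int = 10, timing_iters: int = 40) -> float:
    """Median host-side wall-clock time of fn() in milliseconds."""
    for _ in range(warmup_iters):
        fn()  # warm up caches / compilation before timing
    samples_ms = []
    for _ in range(timing_iters):
        start = time.perf_counter()
        fn()
        samples_ms.append((time.perf_counter() - start) * 1e3)
    # The median is less sensitive than the mean to occasional slow outliers.
    return statistics.median(samples_ms)


# Usage with a CPU-only stand-in; for the repro, fn could be lambda: exec_model(**args).
print(f"{host_timer(lambda: sum(range(10_000))):.4f} ms")
```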

@t-vi
Collaborator Author

t-vi commented Dec 3, 2024

> Is anyone else able to do so?

We ran into these too; I hope it will be possible after #1500 is merged.

@kevinstephano
Collaborator

kevinstephano commented Dec 9, 2024

I explored the inference performance difference between Thunder-nvFuser and Thunder-torch.compile. The difference is that Thunder-torch.compile adds linears and matmuls to its fusion regions and therefore avoids Thunder's per-fusion and per-operator overheads. In the Nsight Systems graph below, you can see that Thunder-torch.compile breaks an inference step into 3 regions: two are executed by torch.compile, and sdpa executes the remaining one. Note that torch.compile does not fuse the matmuls; it just calls the ATen/cuBLAS matmuls.

Per-operator overhead does not look like the issue: in isolation, Thunder-nvFuser and Thunder-torch.compile have similar overheads for a single add operation.

DGX H100-80GB results (single add operation):

| Execution Type | Wall Clock Time (ms) |
|---|---|
| Thunder-nvFuser | 0.072 |
| Thunder-torch.compile | 0.075 |
| Thunder-torch | 0.042 |
| torch.compile | 0.026 |
| torch-eager | 0.007 |
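Assuming, as the sentence above indicates, that this table times a single add in isolation, subtracting the eager time isolates the fixed per-call cost each path adds. The sketch below does that arithmetic on the table's numbers: the two Thunder fusion paths land within about 3 us of each other and both sit roughly 65 to 68 us above eager, which is why per-operator overhead alone does not explain the gap between them. The function name is hypothetical.

```python
# Wall-clock times (ms) copied from the DGX H100-80GB single-add table above.
SINGLE_ADD_MS = {
    "Thunder-nvFuser": 0.072,
    "Thunder-torch.compile": 0.075,
    "Thunder-torch": 0.042,
    "torch.compile": 0.026,
    "torch-eager": 0.007,
}


def extra_over_baseline_us(times_ms: dict, baseline: str = "torch-eager") -> dict:
    """Time per call above the baseline, in microseconds, rounded to 0.1 us."""
    base_ms = times_ms[baseline]
    return {name: round((t_ms - base_ms) * 1000.0, 1) for name, t_ms in times_ms.items()}


print(extra_over_baseline_us(SINGLE_ADD_MS))
```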

There are a few things we can do in the nvFuser execution path to address overheads:

  • [This is the primary issue] Add more operators to nvFuser's fusion region
    • Execute matmuls and linears through nvFuser's fallback path. Ivan has a PR up to do this.
    • Adding matmuls and linears alone to the fusion region does not reduce the overheads enough. Other things we would likely have to do:
      • Take over torch.nn.embedding through a fallback-path mechanism until it is implemented in nvFuser.
      • Take over torch.matmul.
      • Take over torch.triu.
      • Get RoPE to execute through nvFuser as well. When I tried, this was blocked by a bug.
  • Thunder's FusionDefinition caching looks slow, on the order of 20+ us per fusion. We could try a different caching approach, possibly in C++.
  • In nvFuser's code base:
    • The Python execute() wrapper in nvFuser has 6 to 20 us of overhead, depending on the size of the fusion. I am going to try to isolate the overhead.
    • The code we use to update our tensor pointers in nvFuser uses the expression evaluator to recalculate pointer offsets, which is unnecessary if the shapes don't change. This takes 6 us for simple fusions and scales with the number of kernel arguments. I have a potential change that would avoid the expression evaluator when it sees tensors with the same shapes.
[Screenshot: Nsight Systems timeline, 2024-12-05]
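The pointer-offset item above describes skipping a recomputation when argument shapes repeat. The following is a minimal Python sketch of that idea, memoizing a costly per-call computation keyed on each argument's (shape, strides). It is not nvFuser's C++ implementation: `OffsetCache` and `last_element_offsets` are hypothetical stand-ins for the expression-evaluator pass. Only the derived offsets are cached; the base data pointers, which change on every call, would still be read fresh each time. The same keyed-cache pattern is one possible option for the FusionDefinition caching item too.

```python
from typing import Callable, Dict, List, Sequence, Tuple

# (shape, strides) of one kernel argument; a hypothetical stand-in for tensor metadata.
TensorMeta = Tuple[Tuple[int, ...], Tuple[int, ...]]


class OffsetCache:
    """Re-run an expensive offset computation only when argument metadata changes."""

    def __init__(self, compute: Callable[[Sequence[TensorMeta]], List[int]]):
        self._compute = compute
        self._cache: Dict[Tuple[TensorMeta, ...], List[int]] = {}
        self.misses = 0  # number of times the expensive path actually ran

    def get(self, metas: Sequence[TensorMeta]) -> List[int]:
        key = tuple(metas)  # hashable key: shapes and strides of every argument
        offsets = self._cache.get(key)
        if offsets is None:
            self.misses += 1
            offsets = self._compute(metas)
            self._cache[key] = offsets
        return offsets


def last_element_offsets(metas: Sequence[TensorMeta]) -> List[int]:
    """Hypothetical 'expensive' pass: element offset of each argument's last element."""
    return [sum((size - 1) * stride for size, stride in zip(shape, strides))
            for shape, strides in metas]


cache = OffsetCache(last_element_offsets)
prefill = ((1, 6, 2048), (12288, 2048, 1))
for _ in range(3):
    cache.get([prefill])                       # computed once, then served from the cache
cache.get([((1, 7, 2048), (14336, 2048, 1))])  # new shape, so recomputed
print(cache.misses)
```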

@mruberry
Collaborator

mruberry commented Dec 9, 2024

Adding a couple notes from our discussions today:

  • We may want to create fusion regions of limited size so that nvFuser can accept linears without running into cases like "nvFuser linear fusion leads to notebook example timeout" (#1490).
  • At the same time, we may want to require fusion regions to have a minimum size so that fusing is worthwhile.
  • There are other sources of latency outside nvFuser. A few that come to mind are the prologue (both the validation and the extraction parts), the Thunder-to-nvFuser cache, and Python interpreter overhead. We are unlikely to reduce Python interpreter overhead significantly, although we could experiment with some reductions.
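The first two notes above suggest bounding fusion regions from both sides. A minimal sketch of such a policy, assuming a flat list of op names and a hypothetical set of fusible ops (real fusion passes in Thunder work on traces and data dependencies, so this only illustrates the min/max rule): consecutive fusible ops are grouped, groups are split at `max_size`, and any group smaller than `min_size` is left unfused, one op per region.

```python
from typing import Iterable, List, Set


def partition_fusions(ops: Iterable[str], fusible: Set[str],
                      max_size: int, min_size: int) -> List[List[str]]:
    """Group consecutive fusible ops into regions of at most max_size ops.

    Groups (or leftover chunks) with fewer than min_size ops are not worth a
    fusion, so each of their ops becomes its own single-op region.
    """
    regions: List[List[str]] = []
    run: List[str] = []  # current stretch of consecutive fusible ops

    def flush() -> None:
        for start in range(0, len(run), max_size):
            chunk = run[start:start + max_size]
            if len(chunk) >= min_size:
                regions.append(chunk)                   # big enough to fuse
            else:
                regions.extend([op] for op in chunk)    # too small: run unfused
        run.clear()

    for op in ops:
        if op in fusible:
            run.append(op)
        else:
            flush()
            regions.append([op])  # a non-fusible op always runs on its own
    flush()
    return regions


ops = ["embedding", "mul", "add", "linear", "silu", "mul", "linear", "add", "sdpa"]
print(partition_fusions(ops, {"mul", "add", "linear", "silu"}, max_size=3, min_size=2))
```

A greedy split like this can leave a small tail (here the lone trailing add); a real pass would more likely balance chunk sizes and follow data dependencies rather than plain op order.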

@kevinstephano
Collaborator

kevinstephano commented Dec 10, 2024

The items we are going to address in nvFuser:

  • Consume more ops through nvFuser
    • Matmul
    • torch.nn.embedding
    • torch.triu
  • The code we use to update our tensor pointers in nvFuser uses the expression evaluator to recalculate pointer offsets, which is unnecessary if the shapes don't change. This takes 6 us for simple fusions and scales with the number of kernel arguments. I have a potential change that would avoid the expression evaluator when it sees tensors with the same shapes.
