-
This could be a dramatic improvement for speculative decoding... at the moment a huge bottleneck is running a small batch of 2-4 tokens through the larger model.
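To make that bottleneck concrete, here is a minimal sketch of the verification step, assuming an mlx_lm-style `model(tokens, cache=...)` call. The helper and argument names are made up for illustration, the acceptance rule is simplified to greedy prefix matching, and cache rollback on rejection is omitted, so treat it as a sketch rather than a real implementation:

```python
import mlx.core as mx

def verify_draft(target_model, target_cache, last_token, draft_tokens):
    """Run the last accepted token plus k draft tokens through the larger
    (target) model in a single small-batch forward pass, then greedily
    accept the longest prefix the target model agrees with."""
    inp = mx.array([[last_token] + draft_tokens])     # shape (1, 1 + k)
    logits = target_model(inp, cache=target_cache)    # the expensive small-batch forward
    preds = mx.argmax(logits[0], axis=-1).tolist()    # target's pick at each position

    accepted = []
    for i, tok in enumerate(draft_tokens):
        if preds[i] != tok:                           # first disagreement with the draft
            break
        accepted.append(tok)

    # Token the target model proposes right after the accepted prefix:
    # a correction on mismatch, or a "bonus" token if everything matched.
    next_token = preds[len(accepted)]
    return accepted, next_token
```

That single forward pass over 2-4 tokens is exactly the small-batch latency being measured in the benchmarks below.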
-
I modified the benchmark slightly to make sure the prompt cache is fixed length.

```python
import time
import copy

import mlx.core as mx
import mlx_lm
from mlx_lm.models import cache

# model, tokenizer = mlx_lm.load("mlx-community/Llama-3.2-1B-Instruct-bf16")
model, tokenizer = mlx_lm.load("mlx-community/Llama-3.2-1B-Instruct-4bit")
# model, tokenizer = mlx_lm.load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

prompt_cache = cache.make_prompt_cache(model)

# Prefill a fixed 256-token prompt so every timed forward pass sees the
# same cache length.
prompt = mx.array([[100] * 256])
logits = model(prompt, cache=prompt_cache)
mx.eval(logits)

seqlen = 1
print("Seqlen | Toks/s | Fwd Time (ms)")
print("------ | ------ | -------------")
while seqlen <= 32:
    inp = mx.array([[100] * seqlen])
    tic = time.perf_counter()
    its = 25
    for _ in range(its):
        # Deep-copy the cache so it is not mutated across iterations.
        logits = model(inp, cache=copy.deepcopy(prompt_cache))
        mx.eval(logits)
    toc = time.perf_counter()
    s = (toc - tic) / its
    tps = seqlen / s
    ms = 1000 * s
    print(f"{seqlen} | {tps:.3f} | {ms:.3f}")
    seqlen *= 2
```

Overall the results are not that bad. For the quantized 1B model I see a nice increase in toks/sec from
-
Related llama.cpp PR: ggerganov/llama.cpp#10581
-
Hi, since transformer inference is memory bound, the forward time should increase in a step-wise manner as the number of processed tokens grows.
The following image shows the median forward times with 4-bit quantized weights for several sequence lengths. Notably, there is a 4.3x increase when going from sequence length 1 to 8, while the times for 8 to 32 tokens lie on the same step. This suggests that the QMM kernels for small sequence lengths are likely under-optimized.
As a comparison, here are the results for FP16 weights, where the forward times for lengths between 1 and 64 lie on the same step.
*(Image: Screenshot 2024-11-16 at 17 26 53, median forward times with FP16 weights)*
I'm pretty uncertain about this, but based on the last image, shouldn't we expect the QMM times for lengths between 1 and 64/4 = 16 to be on the same step, rather than several times higher?
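A rough roofline-style sanity check of that 64/4 = 16 figure, assuming the per-layer weight matmuls dominate, with illustrative (not measured) symbols: P the parameter count, b bytes per weight, BW memory bandwidth, F peak FLOP/s, and n the number of processed tokens:

$$
t_{\text{fwd}}(n) \approx \max\!\left(\frac{bP}{BW},\ \frac{2nP}{F}\right)
\quad\Longrightarrow\quad
n^{*} \approx \frac{F\,b}{2\,BW}
$$

The width of the flat step scales with the bytes per weight, so going from FP16 (b = 2) to 4-bit (b = 0.5) should shrink it by roughly 4x, i.e. from about n ≤ 64 to about n ≤ 16, which is the expectation behind the question above.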
Experiments were run on an M1 MacBook Air (8 GB) with mlx 0.20.0 and 0.19.3, on macOS 15.2 Beta, using Llama-3.2-1B-Instruct 4-bit/bf16 and the following code: