Quantize embedding / Update quantize API (ml-explore#680)
* more async eval

* quantize embedding / update quantize api

* more updates for quantize

* update for quantize embeddings

* update sd quant API

* update sdxl quants

* error for datasets < batch_size

* async

* fix config loading

* fix quant

* fix tests

* fix req

* remove lm head if tie weights is true

* fix test
awni authored Apr 19, 2024
1 parent f5f189e commit 2146bcd
Showing 28 changed files with 108 additions and 190 deletions.
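Note on the diff below: the recurring change is the move from nn.QuantizedLinear.quantize_module to the module-level nn.quantize API, which can also quantize nn.Embedding layers via an optional (path, module) class predicate. A minimal, self-contained sketch of the new call, assuming mlx >= 0.11; TinyLM and its sizes are invented for illustration:

```python
import mlx.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=128, dims=64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dims)
        self.proj = nn.Linear(dims, dims, bias=False)

model = TinyLM()

# Old API (removed throughout this diff):
#   nn.QuantizedLinear.quantize_module(model, group_size, bits, linear_class_predicate=...)
# New API: converts matching Linear and Embedding modules in place.
nn.quantize(
    model,
    group_size=32,
    bits=4,
    class_predicate=lambda path, module: isinstance(module, (nn.Linear, nn.Embedding)),
)
print(type(model.embed_tokens))  # expected: a quantized embedding module
```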
20 changes: 6 additions & 14 deletions llms/gguf_llm/models.py
@@ -280,24 +280,16 @@ def load(gguf_file: str, repo: str = None):
config = get_config(metadata)
model = Model(ModelArgs(**config))
if quantization is not None:
# quantized the LM head?
qm = model if "lm_head.scales" in weights else model.model
nn.QuantizedLinear.quantize_module(
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(
qm,
**quantization,
class_predicate=class_predicate,
)

def dequantize(k):
weight = weights.pop(f"{k}.weight")
scales = weights.pop(f"{k}.scales")
biases = weights.pop(f"{k}.biases")
weights[f"{k}.weight"] = mx.dequantize(
weight, scales=scales, biases=biases, **quantization
)

# Dequantize embeddings
dequantize("model.embed_tokens")

tokenizer = GGUFTokenizer(metadata)
model.load_weights(list(weights.items()))
return model, tokenizer
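A note on the hunk above: the explicit embedding dequantization is gone because the class predicate keys off the checkpoint itself, so only layers whose quantization scales are actually present get converted. A toy illustration of that predicate (the weights dict here is invented, not real GGUF metadata):

```python
import mlx.nn as nn

# Toy checkpoint: the embedding was stored quantized, the output head was not.
weights = {"model.embed_tokens.scales": None, "lm_head.weight": None}

class_predicate = (
    lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) and f"{p}.scales" in weights
)

print(class_predicate("model.embed_tokens", nn.Embedding(10, 8)))  # True -> quantize
print(class_predicate("lm_head", nn.Linear(8, 10)))                # False -> keep full precision
```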
2 changes: 1 addition & 1 deletion llms/llama/convert.py
@@ -134,7 +134,7 @@ def quantize(weights, config, args):
model.update(tree_unflatten(list(weights.items())))

# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
nn.quantize(model, args.q_group_size, args.q_bits)

# Update the config:
quantized_config["quantization"] = {
2 changes: 1 addition & 1 deletion llms/llama/llama.py
@@ -339,7 +339,7 @@ def load_model(model_path):
quantization = config.pop("quantization", None)
model = Llama(ModelArgs(**config))
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
nn.quantize(model, **quantization)
model.update(tree_unflatten(list(weights.items())))
tokenizer = SentencePieceProcessor(model_file=str(model_path / "tokenizer.model"))
return model, tokenizer
2 changes: 1 addition & 1 deletion llms/llama/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.8.0
mlx>=0.11.0
sentencepiece
torch
numpy
2 changes: 1 addition & 1 deletion llms/mistral/convert.py
@@ -24,7 +24,7 @@ def quantize(weights, config, args):
model.update(tree_unflatten(list(weights.items())))

# Quantize the model:
nn.QuantizedLinear.quantize_module(model, args.q_group_size, args.q_bits)
nn.quantize(model, args.q_group_size, args.q_bits)

# Update the config:
quantized_config["quantization"] = {
2 changes: 1 addition & 1 deletion llms/mistral/mistral.py
@@ -183,7 +183,7 @@ def load_model(folder: str):
weights = tree_unflatten(list(weights.items()))
model = Mistral(model_args)
if quantization is not None:
nn.QuantizedLinear.quantize_module(model, **quantization)
nn.quantize(model, **quantization)
model.update(weights)
mx.eval(model.parameters())
return model, tokenizer
2 changes: 1 addition & 1 deletion llms/mistral/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.8.0
mlx>=0.11.0
sentencepiece
torch
numpy
5 changes: 1 addition & 4 deletions llms/mixtral/convert.py
@@ -60,13 +60,10 @@ def quantize(weights, config, args):
model.update(all_weights)

# Quantize the model:
nn.QuantizedLinear.quantize_module(
nn.quantize(
model,
args.q_group_size,
args.q_bits,
# TODO: Quantize gate matrices when < 32 tiles supported
linear_class_predicate=lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0] != 8,
)

# Extract the subset of quantized weights:
6 changes: 1 addition & 5 deletions llms/mixtral/mixtral.py
@@ -217,11 +217,7 @@ def load_model(folder: str):
weights = tree_unflatten(list(weights.items()))
model = Mixtral(model_args)
if quantization is not None:
# TODO: Quantize gate matrices when < 32 tiles supported
quantization["linear_class_predicate"] = (
lambda m: isinstance(m, nn.Linear) and m.weight.shape[0] != 8
)
nn.QuantizedLinear.quantize_module(model, **quantization)
nn.quantize(model, **quantization)

model.update(weights)
return model, tokenizer
2 changes: 1 addition & 1 deletion llms/mixtral/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.8.0
mlx>=0.11.0
sentencepiece
torch
numpy
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/cohere.py
@@ -185,7 +185,7 @@ def __call__(
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out, cache

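Aside on the change above (repeated in gemma.py and olmo.py below): the tied-embedding output projection now goes through nn.Embedding.as_linear instead of a manual matmul, which keeps working after the embedding has been swapped for a quantized variant. A tiny sketch with made-up shapes:

```python
import mlx.core as mx
import mlx.nn as nn

embed = nn.Embedding(100, 32)      # tied input/output embedding
h = mx.random.normal((1, 4, 32))   # hidden states

logits_old = h @ embed.weight.T    # previous pattern
logits_new = embed.as_linear(h)    # new pattern; also valid once quantized
print(mx.allclose(logits_old, logits_new))  # array(True, dtype=bool)
```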
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/gemma.py
@@ -169,7 +169,7 @@ def __call__(
cache=None,
):
out, cache = self.model(inputs, cache)
out = out @ self.model.embed_tokens.weight.T
out = self.model.embed_tokens.as_linear(out)
return out, cache

@property
2 changes: 1 addition & 1 deletion llms/mlx_lm/models/olmo.py
@@ -142,7 +142,7 @@ def __call__(
h = self.norm(h)

if self.weight_tying:
return h @ self.wte.weight.T, cache
return self.wte.as_linear(h), cache

return self.ff_out(h), cache

13 changes: 9 additions & 4 deletions llms/mlx_lm/models/qwen2.py
@@ -172,19 +172,24 @@ def __init__(self, args: ModelArgs):
self.args = args
self.model_type = args.model_type
self.model = Qwen2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache

def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
if self.args.tie_word_embeddings:
weights.pop("lm_head.weight", None)
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
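A short sketch summarizing the tie_word_embeddings handling above (the same pattern is applied to starcoder2.py below); the helper names here are illustrative, not repository code:

```python
# Forward pass: with tied weights there is no lm_head module at all,
# so the output projection reuses the input embedding.
def lm_logits(args, model, lm_head, hidden):
    if args.tie_word_embeddings:
        return model.embed_tokens.as_linear(hidden)
    return lm_head(hidden)

# Checkpoint loading: an exported lm_head.weight is redundant when tied,
# so sanitize() now drops it instead of copying embed_tokens into it.
def sanitize(args, weights):
    if args.tie_word_embeddings:
        weights.pop("lm_head.weight", None)
    return weights
```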
14 changes: 7 additions & 7 deletions llms/mlx_lm/models/starcoder2.py
@@ -149,20 +149,20 @@ def __init__(self, args: ModelArgs):
self.args = args
self.model_type = args.model_type
self.model = Starcoder2Model(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

def __call__(
self,
inputs: mx.array,
cache=None,
):
out, cache = self.model(inputs, cache)
return self.lm_head(out), cache

def sanitize(self, weights):
if self.args.tie_word_embeddings and "lm_head.weight" not in weights:
weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
return weights
if self.args.tie_word_embeddings:
out = self.model.embed_tokens.as_linear(out)
else:
out = self.lm_head(out)
return out, cache

@property
def layers(self):
2 changes: 1 addition & 1 deletion llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
mlx>=0.10
mlx>=0.11
numpy
transformers>=4.39.3
protobuf
1 change: 1 addition & 0 deletions llms/mlx_lm/tokenizer_utils.py
@@ -74,6 +74,7 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):

def __init__(self, tokenizer):
self._tokenizer = tokenizer
self._tokenizer.decode([0])
self.reset()

def reset(self):
5 changes: 5 additions & 0 deletions llms/mlx_lm/tuner/trainer.py
@@ -79,6 +79,11 @@ def default_loss(model, inputs, targets, lengths):
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
# Sort by length:
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)

# Make the batches:
batch_idx = [
80 changes: 33 additions & 47 deletions llms/mlx_lm/utils.py
@@ -15,7 +15,7 @@
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizer

# Local imports
from .sample_utils import top_p_sampling
@@ -31,12 +31,6 @@

MAX_FILE_SIZE_GB = 5

linear_class_predicate = (
lambda m: isinstance(m, nn.Linear)
and m.weight.shape[0]
!= 8 # avoid quantizing gate layers, otherwise we have to re-quant and upload all the mixtral models
)


def _get_classes(config: dict):
"""
@@ -188,14 +182,14 @@ def _step(y):
repetition_context = repetition_context[-repetition_context_size:]
return y, prob

y, prob = _step(y)
y, p = _step(y)

mx.async_eval(y)
while True:
sync = mx.async_eval(y)
next_out = _step(y)
sync.wait()
yield y.item(), prob
y, prob = next_out
next_y, next_p = _step(y)
mx.async_eval(next_y)
yield y.item(), p
y, p = next_y, next_p


def generate(
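The loop above replaces the wait()-based synchronization with a pipelined pattern: while the caller consumes the current token, the graph for the next token is already built and submitted via mx.async_eval. A stripped-down, runnable sketch in which step is a stand-in for a real decode step:

```python
import mlx.core as mx

def step(y):
    # Stand-in for one transformer decode step returning (next_token, prob).
    return y + 1, mx.array(1.0)

def decode(y):
    y, p = step(y)
    mx.async_eval(y)               # start evaluating the first token
    while True:
        next_y, next_p = step(y)   # build the next token's graph
        mx.async_eval(next_y)      # overlap its evaluation with the yield below
        yield y.item(), p          # .item() blocks only on the current token
        y, p = next_y, next_p

gen = decode(mx.array(0))
print([tok for tok, _ in (next(gen) for _ in range(3))])  # [1, 2, 3]
```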
@@ -283,6 +277,16 @@ def generate(
return detokenizer.text


def load_config(model_path: Path) -> dict:
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise
return config


def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
"""
Load and initialize the model from a given path.
@@ -300,13 +304,8 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
FileNotFoundError: If the weight files (.safetensors) are not found.
ValueError: If the model class or args class are not found or cannot be instantiated.
"""
try:
with open(model_path / "config.json", "r") as f:
config = json.load(f)
quantization = config.get("quantization", None)
except FileNotFoundError:
logging.error(f"Config file not found in {model_path}")
raise

config = load_config(model_path)

weight_files = glob.glob(str(model_path / "*.safetensors"))
if not weight_files:
@@ -325,26 +324,17 @@
if hasattr(model, "sanitize"):
weights = model.sanitize(weights)

if quantization is not None:
# for legacy models that don't have lm_head quant due to non-32 dims
if "lm_head.scales" not in weights.keys():
vocab_size = config["vocab_size"]
extended_linear_class_predicate = (
lambda layer: linear_class_predicate(layer)
and layer.weight.shape[0] != vocab_size
)
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=extended_linear_class_predicate,
)
# for models that have lm_head quant
else:
nn.QuantizedLinear.quantize_module(
model,
**quantization,
linear_class_predicate=linear_class_predicate,
)
if (quantization := config.get("quantization", None)) is not None:
# Handle legacy models which may not have everything quantized
class_predicate = (
lambda p, m: isinstance(m, (nn.Linear, nn.Embedding))
and f"{p}.scales" in weights
)
nn.quantize(
model,
**quantization,
class_predicate=class_predicate,
)

model.load_weights(list(weights.items()))

@@ -395,10 +385,9 @@ def fetch_from_hub(
model_path: Path, lazy: bool = False
) -> Tuple[nn.Module, dict, PreTrainedTokenizer]:
model = load_model(model_path, lazy)
config = AutoConfig.from_pretrained(model_path)
config = load_config(model_path)
tokenizer = load_tokenizer(model_path)

return model, config.to_dict(), tokenizer
return model, config, tokenizer


def make_shards(weights: dict, max_file_size_gb: int = MAX_FILE_SIZE_GB) -> list:
@@ -543,10 +532,7 @@ def quantize_model(
Tuple: Tuple containing quantized weights and config.
"""
quantized_config = copy.deepcopy(config)

nn.QuantizedLinear.quantize_module(
model, q_group_size, q_bits, linear_class_predicate=linear_class_predicate
)
nn.quantize(model, q_group_size, q_bits)
quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits}
quantized_weights = dict(tree_flatten(model.parameters()))

2 changes: 1 addition & 1 deletion llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.9.0"
__version__ = "0.10.0"