diff --git a/examples/ggml-quant/ggmltensor.py b/examples/ggml-quant/ggmltensor.py new file mode 100644 index 0000000000..96b7bcdcf9 --- /dev/null +++ b/examples/ggml-quant/ggmltensor.py @@ -0,0 +1,414 @@ +import collections +from enum import Enum +import functools +import json +import operator +import pathlib +import struct + +import numpy +import torch + + +class GgufMetadataValueType(Enum): # // gguf_metadata_value_type (uint32_t enum) in llama.cpp + uint8 = 0 + int8 = 1 + uint16 = 2 + int16 = 3 + uint32 = 4 + int32 = 5 + float32 = 6 + bool = 7 + string = 8 + array = 9 + uint64 = 10 + int64 = 11 + float64 = 12 + + +class GgmlType(Enum): # uint32 + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + # 4, 5 are Q4_2/3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + IQ2_XXS = 16 + IQ2_XS = 17 + IQ3_XXS = 18 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = (22,) + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + + +GGML_BLOCK_SIZES = { + GgmlType.F32: 4, + GgmlType.Q4_0: 2 + 16, + GgmlType.Q8_0: 2 + 32, + GgmlType.Q2_K: 256 // 16 + 256 // 4 + 2 + 2, + GgmlType.Q3_K: 256 // 8 + 256 // 4 + 12 + 2, + GgmlType.Q4_K: 2 + 2 + 12 + 256 // 2, + GgmlType.Q5_K: 2 + 2 + 12 + 256 // 8 + 256 // 2, + GgmlType.Q6_K: 256 // 2 + 256 // 4 + 256 // 16 + 2, +} + +GGML_ELEMENTS_PER_BLOCK = { + GgmlType.F32: 1, + GgmlType.Q4_0: 32, + GgmlType.Q8_0: 32, + GgmlType.Q2_K: 256, + GgmlType.Q3_K: 256, + GgmlType.Q4_K: 256, + GgmlType.Q5_K: 256, + GgmlType.Q6_K: 256, +} + + +def compute_number_of_blocks(shape, ggml_type): + # todo: from the code, it looks like the padding might be per-row + numel = functools.reduce(operator.mul, shape, 1) + block_numel = GGML_ELEMENTS_PER_BLOCK[ggml_type] + num_blocks = (numel + block_numel - 1) // block_numel + return num_blocks + + +def read_gguf_string(f): + (length,) = struct.unpack("L", f.read(8)) + res = f.read(length).decode("utf8") + return res + + +def read_metadata_kv(f): + k = read_gguf_string(f) + (metadata_value_type,) = struct.unpack("I", f.read(4)) + metadata_value_type = GgufMetadataValueType(metadata_value_type) + + def read_value(f, typ): + if typ == GgufMetadataValueType.string: + v = read_gguf_string(f) + elif typ == GgufMetadataValueType.uint32: + (v,) = struct.unpack("I", f.read(4)) + v = numpy.uint32(v) + elif typ == GgufMetadataValueType.int32: + (v,) = struct.unpack("i", f.read(4)) + v = numpy.int32(v) + elif typ == GgufMetadataValueType.float32: + (v,) = struct.unpack("f", f.read(4)) + v = numpy.float32(v) + elif typ == GgufMetadataValueType.array: + (element_metadata_value_type,) = struct.unpack("> 2) & 0x3) << 4)).view(torch.int8) - 32 + q3 = ((block_q6_K_ql[:, :, 0, :] >> 4) | (((block_q6_K_qh[:, :, :] >> 4) & 0x3) << 4)).view(torch.int8) - 32 + q4 = ((block_q6_K_ql[:, :, 1, :] >> 4) | (((block_q6_K_qh[:, :, :] >> 6) & 0x3) << 4)).view(torch.int8) - 32 + y[:, :, 0, :] = block_q6_K_d * block_q6_K_scales[:, :, 0, :] * q1 + y[:, :, 1, :] = block_q6_K_d * block_q6_K_scales[:, :, 1, :] * q2 + y[:, :, 2, :] = block_q6_K_d * block_q6_K_scales[:, :, 2, :] * q3 + y[:, :, 3, :] = block_q6_K_d * block_q6_K_scales[:, :, 3, :] * q4 + numel = functools.reduce(operator.mul, shape, 1) + # in ggml the first is the fastest moving dimension + y = y.reshape(-1)[:numel].reshape(*shape[::-1]).permute(*list(range(len(shape) - 1, -1, -1))) + return y + + +# define QK4_0 32 +# typedef struct { +# ggml_half d; // delta +# uint8_t qs[QK4_0 / 2]; // nibbles / quants +# } block_q4_0; +@register_ggml_dequantizer(GgmlType.Q4_0) +def dequantize_q4_0(block_data, shape, dtype=torch.float32): + num_blocks, num_block_bytes = block_data.shape + assert num_block_bytes == GGML_BLOCK_SIZES[GgmlType.Q4_0] + block_q4_0_d = block_data[:, :2].view(torch.float16).to(torch.float32).reshape(num_blocks, 1) + block_q4_0_qs = block_data[:, 2:] + num_blocks, _ = block_q4_0_qs.shape + y = torch.empty((num_blocks, 2, 16), dtype=dtype, device=block_data.device) + + q0 = block_q4_0_qs & 0xF + q0 = q0.view(torch.int8) - 8 + q1 = (block_q4_0_qs >> 4).view(torch.int8) - 8 + y[:, 0, :] = block_q4_0_d * q0 + y[:, 1, :] = block_q4_0_d * q1 + numel = functools.reduce(operator.mul, shape, 1) + # in ggml the first is the fastest moving dimension + y = y.reshape(-1)[:numel].reshape(*shape[::-1]).permute(*list(range(len(shape) - 1, -1, -1))) + return y + + +@register_ggml_dequantizer(GgmlType.F32) +@register_ggml_dequantizer(GgmlType.F16) +def dequantize_noop(block_data, shape, dtype=torch.float32): + assert block_data.shape == shape + return block_data.to(dtype) + + +def dequantize(qw, typ, shape, dtype=torch.float32): + dequantizer = GGML_DEQUANTIZERS.get(typ) + if typ is None: + raise NotImplementedError("Cannot decode {typ}") + return dequantizer(qw, shape, dtype) + + +def cat_quantized(tensors, infos, dim): + assert isinstance(tensors, collections.abc.Sequence) and len(tensors) > 0 + assert isinstance(infos, collections.abc.Sequence) and len(infos) == len(tensors) + typ_0, shape_0 = infos[0] + assert typ_0 not in (GgmlType.F16, GgmlType.F32), "This only works for quantized" + total_dim = len(shape_0) + if dim < 0: + dim += total_dim + assert 0 <= dim < total_dim + concat_dim_size = 0 + reshaped_tensors = [] + for (typ_i, shape_i), tensor_i in zip(infos, tensors): + assert typ_i == typ_0 and len(shape_i) == total_dim, "concatenated tensors must have same type and dimension" + assert ( + shape_i[:dim] == shape_0[:dim] and shape_i[dim + 1 :] == shape_0[dim + 1 :] + ), "shapes must match except in the concat dimension" + concat_dim_size += shape_i[dim] + numel_in_back = functools.reduce(operator.mul, shape_i[: dim + 1]) + assert numel_in_back % GGML_ELEMENTS_PER_BLOCK[typ_0] == 0 + blocks_in_back = numel_in_back // GGML_ELEMENTS_PER_BLOCK[typ_0] + reshaped_tensors.append(tensor_i.reshape(-1, blocks_in_back, GGML_BLOCK_SIZES[typ_0])) + new_tensor = torch.cat(reshaped_tensors, dim=1).view(-1, GGML_BLOCK_SIZES[typ_0]) + new_info = (typ_0, (*shape_0[:dim], concat_dim_size, *shape_0[dim + 1 :])) + return new_tensor, new_info + + +def merge_attention_weights(qq, info_q, qk, info_k, qv, info_v): + typ_q, shape_q = info_q + typ_k, shape_k = info_k + typ_v, shape_v = info_v + + num_blocks = qq.shape[0] + assert qq.shape[1] == GGML_BLOCK_SIZES[typ_q] + assert shape_q[0] % GGML_ELEMENTS_PER_BLOCK[typ_q] == 0 + qq_swizzled = ( + qq.view(-1, 64, 2, shape_q[0] // GGML_ELEMENTS_PER_BLOCK[typ_q], qq.shape[1]).transpose(1, 2).reshape(*qq.shape) + ) + qk_swizzled = ( + qk.view(-1, 64, 2, shape_k[0] // GGML_ELEMENTS_PER_BLOCK[typ_k], qk.shape[1]).transpose(1, 2).reshape(*qk.shape) + ) + + dqq2 = dequantize(qq_swizzled, typ_q, shape_q) + dqk2 = dequantize(qk_swizzled, typ_k, shape_k) + + assert shape_k[0] % GGML_ELEMENTS_PER_BLOCK[typ_k] == 0 + assert shape_v[0] % GGML_ELEMENTS_PER_BLOCK[typ_v] == 0 + + q_all, (typ_all, shape_all) = cat_quantized( + [qq_swizzled, qk_swizzled, qv], + [ + (typ_q, (shape_q[0], shape_q[1] // 8, 8)), + (typ_k, (shape_k[0], shape_k[1] // 8, 8)), + (typ_v, (shape_v[0], shape_v[1] // 8, 8)), + ], + dim=1, + ) + shape_all = (shape_all[0], shape_all[1] * shape_all[2]) + return q_all, (typ_all, shape_all) + + +class GgmlDataReader: + def __init__(self, model_file_name): + model_file_path = pathlib.Path(model_file_name).expanduser() + data = json.load(open(model_file_path)) + (model_data,) = [d for d in data["layers"] if d["mediaType"] == "application/vnd.ollama.image.model"] + model_data_blob = model_data["digest"].replace(":", "-") + + p = model_file_path + while p.name != ".ollama": + p = p.parent + model_blob_path = p / "models" / "blobs" / model_data_blob + + self.f = open(model_blob_path, "rb") + metadata, tensor_infos = read_gguf_header(self.f) + + self.tensor_infos = {name: (typ, shape, offset) for name, shape, typ, offset in tensor_infos} + + def read_tensor(self, ggml_name): + typ, shape, offset = self.tensor_infos[ggml_name] + + self.f.seek(offset) + num_block_bytes = GGML_BLOCK_SIZES[typ] + num_blocks = compute_number_of_blocks(shape, typ) + numel = functools.reduce(operator.mul, shape, 1) + + block_data_raw = self.f.read(num_blocks * num_block_bytes) + block_data = numpy.frombuffer(block_data_raw, dtype=numpy.uint8).reshape(num_blocks, num_block_bytes) + if typ == GgmlType.F32: + # in ggml the first is the fastest moving dimension + block_data = ( + block_data.view(numpy.float32) + .reshape(-1)[:numel] + .reshape(*shape[::-1]) + .transpose(*list(range(len(shape) - 1, -1, -1))) + ) + elif typ == GgmlType.F16: + block_data = ( + block_data.view(numpy.float16) + .reshape(-1)[:numel] + .reshape(*shape[::-1]) + .transpose(*list(range(len(shape) - 1, -1, -1))) + ) + import warnings + + with warnings.catch_warnings(action="ignore", category=UserWarning): # ignore read-only warning + return torch.from_numpy(block_data), (typ, shape) + + def close(self): + self.f.close() + + def get_parameter(self, name): + ggml_name = ( + name.replace("transformer.wte.weight", "token_embd.weight") + .replace("transformer.h.", "blk.") + .replace(".norm_1.weight", ".attn_norm.weight") + .replace(".attn.attn.weight", ".attn_q.weight") + .replace(".attn.proj.weight", ".attn_output.weight") + .replace(".norm_2.weight", ".ffn_norm.weight") + .replace(".mlp.fc_1.weight", ".ffn_gate.weight") + .replace(".mlp.fc_2.weight", ".ffn_up.weight") + .replace(".mlp.proj.weight", ".ffn_down.weight") + .replace("transformer.ln_f.weight", "output_norm.weight") + .replace("lm_head.weight", "output.weight") + ) + if not name.endswith("attn.attn.weight"): + q, (typ, shape) = self.read_tensor(ggml_name) + else: + ggml_name_key = ggml_name.replace("attn_q", "attn_k") + ggml_name_value = ggml_name.replace("attn_q", "attn_v") + q, (typ, shape) = merge_attention_weights( + *self.read_tensor(ggml_name), *self.read_tensor(ggml_name_key), *self.read_tensor(ggml_name_value) + ) + return q, (typ, shape) diff --git a/examples/ggml-quant/thunder_ggmlquant.ipynb b/examples/ggml-quant/thunder_ggmlquant.ipynb new file mode 100644 index 0000000000..103902fbce --- /dev/null +++ b/examples/ggml-quant/thunder_ggmlquant.ipynb @@ -0,0 +1,748 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2db72f9", + "metadata": { + "hideCode": false, + "hidePrompt": false + }, + "source": [ + "# Loading GGML / Ollama weights into LitGPT" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d3b8246e", + "metadata": { + "hideCode": false, + "hidePrompt": false + }, + "outputs": [], + "source": [ + "import thunder\n", + "import torch\n", + "import ggmltensor\n", + "\n", + "def load_ggml_weights(model, fn): \n", + " ggml_quant = ggmltensor.GgmlDataReader(fn)\n", + "\n", + " for n, p in model.named_parameters(): \n", + " qw, (typ, shape) = ggml_quant.get_parameter(n)\n", + " with torch.no_grad():\n", + " w = ggmltensor.dequantize(qw, typ, shape, dtype=p.dtype).to(p.device)\n", + " p.copy_(w.t())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "51d82e38-ddcc-443b-b846-1d607404d7b7", + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n", + "\n", + "import sys\n", + "import time\n", + "from pathlib import Path\n", + "from typing import Any, Literal, Optional\n", + "\n", + "import lightning as L\n", + "import torch\n", + "import torch._dynamo.config\n", + "import torch._inductor.config\n", + "#from lightning.fabric.plugins import BitsandbytesPrecision\n", + "\n", + "from litgpt import GPT, Config, PromptStyle, Tokenizer\n", + "from litgpt.prompts import has_prompt_style, load_prompt_style\n", + "from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint\n", + "\n", + "\n", + "\n", + "def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:\n", + " if torch._dynamo.is_compiling():\n", + " # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly\n", + " distribution = torch.empty_like(probs).exponential_(1)\n", + " return torch.argmax(probs / distribution, dim=-1, keepdim=True)\n", + " return torch.multinomial(probs, num_samples=1)\n", + "\n", + "\n", + "def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:\n", + " sorted_logits, sorted_indices = torch.sort(logits, descending=False)\n", + " cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n", + " # Example:\n", + " # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]\n", + " # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7\n", + " sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n", + " # Keep at least 1 token always to prevent the case where no token is selected\n", + " # In this case the most probable one is always kept\n", + " sorted_indices_to_remove[-1:] = 0\n", + " indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)\n", + " logits = logits.masked_fill(indices_to_remove, float(\"-inf\"))\n", + " return logits\n", + "\n", + "\n", + "def sample(\n", + " logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0\n", + ") -> torch.Tensor:\n", + " if top_p < 0.0 or top_p > 1.0:\n", + " raise ValueError(f\"top_p must be in [0, 1], got {top_p}\")\n", + " logits = logits[0, -1]\n", + " # optionally crop the logits to only the top k options\n", + " if top_k is not None:\n", + " v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n", + " # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n", + " logits = torch.full_like(logits, float(\"-inf\")).scatter_(-1, i, v)\n", + " # optionally scale the logits and sample from a probability distribution\n", + " if temperature > 0.0 or top_p > 0.0:\n", + " if temperature > 0.0:\n", + " logits = logits / temperature\n", + " # optionally crop the logits to smallest set of logits with a cumulative probability above top_p\n", + " if top_p < 1.0:\n", + " logits = sample_top_p(logits, top_p)\n", + " probs = torch.nn.functional.softmax(logits, dim=-1)\n", + " return multinomial_num_samples_1(probs)\n", + " return torch.argmax(logits, dim=-1, keepdim=True)\n", + "\n", + "\n", + "def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:\n", + " logits = model(x, input_pos)\n", + " next = sample(logits, **kwargs)\n", + " return next.to(dtype=x.dtype)\n", + "\n", + "\n", + "@torch.inference_mode()\n", + "def generate(\n", + " model: GPT,\n", + " prompt: torch.Tensor,\n", + " max_returned_tokens: int,\n", + " *,\n", + " temperature: float = 1.0,\n", + " top_k: Optional[int] = None,\n", + " top_p: float = 1.0,\n", + " eos_id: Optional[int] = None,\n", + ") -> torch.Tensor:\n", + " \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n", + "\n", + " The implementation of this function is modified from A. Karpathy's nanoGPT.\n", + "\n", + " Args:\n", + " model: The model to use.\n", + " prompt: Tensor of shape (T) with indices of the prompt sequence.\n", + " max_returned_tokens: The maximum number of tokens to return (given plus generated).\n", + " temperature: Scales the predicted logits by 1 / temperature.\n", + " top_k: If specified, only sample among the tokens with the k highest probabilities.\n", + " top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n", + " In top-p sampling, the next token is sampled from the highest probability tokens\n", + " whose cumulative probability exceeds the threshold `top_p`. When specified,\n", + " it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n", + " to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n", + " It can be used in conjunction with `top_k` and `temperature` with the following order\n", + " of application:\n", + "\n", + " 1. `top_k` sampling\n", + " 2. `temperature` scaling\n", + " 3. `top_p` sampling\n", + "\n", + " For more details, see https://arxiv.org/abs/1904.09751\n", + " or https://huyenchip.com/2024/01/16/sampling.html#top_p\n", + " eos_id: If specified, stop generating any more token once the token is triggered.\n", + " \"\"\"\n", + " T = prompt.size(0)\n", + " assert max_returned_tokens > T\n", + " if model.max_seq_length < max_returned_tokens - 1:\n", + " # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a\n", + " # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do\n", + " # not support it to avoid negatively impacting the overall speed\n", + " raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n", + "\n", + " device = prompt.device\n", + " tokens = [prompt]\n", + " input_pos = torch.tensor([T], device=device)\n", + " token = next_token(\n", + " model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p\n", + " ).clone()\n", + " tokens.append(token)\n", + " for _ in range(2, max_returned_tokens - T + 1):\n", + " token = next_token(\n", + " model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p\n", + " ).clone()\n", + " tokens.append(token)\n", + " if token == eos_id:\n", + " break\n", + " input_pos = input_pos.add_(1)\n", + " return torch.cat(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "39943398-afaa-40bb-8de1-39d021914245", + "metadata": {}, + "outputs": [], + "source": [ + "with torch.inference_mode():\n", + " prompt: str = \"What food do llamas eat?\"\n", + " num_samples: int = 1\n", + " max_new_tokens: int = 256\n", + " top_k: Optional[int] = 50\n", + " top_p: float = 1.0\n", + " temperature: float = 0.8\n", + " checkpoint_dir: Path = Path(\"/home/tv/data/firma/grid/thunder/litgpt/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/\")\n", + " quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = \"bnb.nf4\"\n", + " precision: Optional[str] = \"bf16-true\"\n", + " compile: bool = False\n", + "# litgpt generate base --quantize bnb.nf4 --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256\n", + "\n", + " \"\"\"Generates text samples based on a pre-trained model and tokenizer.\n", + "\n", + " Args:\n", + " prompt: The prompt string to use for generating the samples.\n", + " num_samples: The number of text samples to generate.\n", + " max_new_tokens: The number of generation steps to take.\n", + " top_k: The number of top most probable tokens to consider in the sampling process.\n", + " top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n", + " In top-p sampling, the next token is sampled from the highest probability tokens\n", + " whose cumulative probability exceeds the threshold `top_p`. When specified,\n", + " it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n", + " to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n", + " It can be used in conjunction with `top_k` and `temperature` with the following order\n", + " of application:\n", + "\n", + " 1. `top_k` sampling\n", + " 2. `temperature` scaling\n", + " 3. `top_p` sampling\n", + "\n", + " For more details, see https://arxiv.org/abs/1904.09751\n", + " or https://huyenchip.com/2024/01/16/sampling.html#top_p\n", + " temperature: A value controlling the randomness of the sampling process. Higher values result in more random\n", + " samples.\n", + " checkpoint_dir: The checkpoint directory to load.\n", + " quantize: Whether to quantize the model and using which method:\n", + " - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n", + " - bnb.int8: 8-bit quantization from bitsandbytes\n", + " for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n", + " precision: Indicates the Fabric precision setting to use.\n", + " compile: Whether to compile the model.\n", + " \"\"\"\n", + " precision = precision or get_default_supported_precision(training=False)\n", + "\n", + " #plugins = BitsandbytesPrecision(mode='nf4', dtype=torch.bfloat16)\n", + " \n", + " precision = 'bf16-true'\n", + "\n", + " fabric = L.Fabric(devices=1, precision=precision) #, plugins=plugins)\n", + "\n", + " check_valid_checkpoint_dir(checkpoint_dir)\n", + " config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n", + "\n", + " checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n", + "\n", + " tokenizer = Tokenizer(checkpoint_dir)\n", + " prompt_style = (\n", + " load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n", + " )\n", + "\n", + " prompt = prompt_style.apply(prompt)\n", + " encoded = tokenizer.encode(prompt, device=fabric.device)\n", + " prompt_length = encoded.size(0)\n", + " max_returned_tokens = prompt_length + max_new_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7afe7f82-d468-4570-83f5-9bfc4817183c", + "metadata": {}, + "outputs": [], + "source": [ + "if 0:\n", + " fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n", + " t0 = time.perf_counter()\n", + " with fabric.init_module(empty_init=True):\n", + " model = GPT(config)\n", + " fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n", + " with fabric.init_tensor():\n", + " # set the max_seq_length to limit the memory usage to what we need\n", + " model.max_seq_length = max_returned_tokens\n", + " # enable the kv cache\n", + " model.set_kv_cache(batch_size=1)\n", + " model.eval()\n", + "\n", + " if compile:\n", + " torch._dynamo.config.automatic_dynamic_shapes = True\n", + " torch._inductor.config.triton.unique_kernel_names = True\n", + " torch._inductor.config.coordinate_descent_tuning = True\n", + " global next_token\n", + " next_token = torch.compile(next_token, mode=\"reduce-overhead\")\n", + "\n", + " model = fabric.setup_module(model)\n", + " \n", + " ggml_fn = '~/.ollama/models/manifests/registry.ollama.ai/library/llama3/latest'\n", + "\n", + " t0 = time.perf_counter()\n", + " #load_checkpoint(fabric, model, checkpoint_path)\n", + " load_ggml_weights(model._original_module, ggml_fn)\n", + "\n", + " fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "483cc7da-4db8-4953-b5c7-657b709cc8ce", + "metadata": {}, + "outputs": [], + "source": [ + "if 0:\n", + " with torch.inference_mode():\n", + " L.seed_everything(1234)\n", + " for i in range(num_samples):\n", + " t0 = time.perf_counter()\n", + " y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)\n", + " t = time.perf_counter() - t0\n", + " for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", + " fabric.print(tokenizer.decode(y))\n", + " tokens_generated = y.size(0) - prompt_length\n", + " fabric.print(\n", + " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", + " )\n", + " if fabric.device.type == \"cuda\":\n", + " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" + ] + }, + { + "cell_type": "markdown", + "id": "8705fb23-8b24-488f-ad85-72e4cdd98c25", + "metadata": {}, + "source": [ + "# Thunder transform" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a76d44e0-dd65-4e6c-a747-96f2ef1e33fb", + "metadata": {}, + "outputs": [], + "source": [ + "from collections.abc import Sequence\n", + "\n", + "import thunder\n", + "from thunder.core.transform_common import EarlyTransform\n", + "from thunder.core import utils\n", + "from thunder.core import prims\n", + "import torch\n", + "\n", + "from thunder.transforms.utils import (\n", + " get_orig_and_thunder_module_proxies_from_prologue,\n", + " get_checks,\n", + " add_trace_output,\n", + ")\n", + "\n", + "import ggmltensor\n", + "\n", + "ggmlquant_executor = thunder.extend.OperatorExecutor(\"quant_ggml\", version=0.1)\n", + "\n", + "def ggmlquant_matmul_meta(x, qweight, ggmltype: int, shape):\n", + " assert isinstance(shape, Sequence) and len(shape) == 2\n", + " assert x.shape[-1] == shape[1], f\"{x.shape=}, rhs {shape=}\"\n", + " return thunder.TensorProxy(like=x, shape=(*x.shape[:-1], shape[0]))\n", + "\n", + "\n", + "def ggmlquant_matmul_impl(x, qweight, ggmltype: int, shape):\n", + " w = ggmltensor.dequantize(qweight, ggmltensor.GgmlType(ggmltype), shape, dtype=x.dtype)\n", + " return x @ w\n", + "\n", + "def ggmlquant_embed_meta(x, qweight, ggmltype: int, shape):\n", + " assert isinstance(shape, Sequence) and len(shape) == 2\n", + " # checks for mul\n", + " return thunder.TensorProxy(like=x, shape=(*x.shape, shape[1]))\n", + "\n", + "def ggmlquant_embed_impl(x, qweight, ggmltype: int, shape):\n", + " w = ggmltensor.dequantize(qweight, ggmltensor.GgmlType(ggmltype), shape, dtype=torch.bfloat16)\n", + " return torch.nn.functional.embedding(x, w.t()) \n", + "\n", + "ggmlquant_matmul = ggmlquant_executor.register_operator(\n", + " \"ggmlquant_matmul\", meta=ggmlquant_matmul_meta, fn=ggmlquant_matmul_impl\n", + ")\n", + "\n", + "ggmlquant_embed = ggmlquant_executor.register_operator(\n", + " \"ggmlquant_embed\", meta=ggmlquant_embed_meta, fn=ggmlquant_embed_impl\n", + ")\n", + "\n", + "\n", + "\n", + "class GGMLQuantTransform(EarlyTransform):\n", + " def __init__(self, model_file_name, device):\n", + " self.quant_states = {}\n", + " self.quantized_submodule_names = set()\n", + " self.device = device\n", + " self.model_file_name = model_file_name\n", + "\n", + " def transform_module(self, model: thunder.ThunderModule):\n", + " ggml_quant = ggmltensor.GgmlDataReader(self.model_file_name)\n", + " self.thunder_module = model\n", + "\n", + " def convert_layer_with_weight(tm, name):\n", + " self.quantized_submodule_names.add(name)\n", + " weight_name = f\"{name}.weight\"\n", + " w = tm.get_parameter(weight_name)\n", + " qw, (typ, shape) = ggml_quant.get_parameter(weight_name)\n", + " tm._overrides_parameters[weight_name] = qw.to(self.device)\n", + " if not qw.is_floating_point():\n", + " self.quant_states[weight_name] = {\"typ\": typ, \"shape\": shape}\n", + "\n", + " for n, submodule in model._model.named_modules():\n", + " if hasattr(submodule, \"weight\"):\n", + " convert_layer_with_weight(model, n)\n", + " ggml_quant.close()\n", + "\n", + " def transform_state_dict_for_submodule(self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict):\n", + " raise NotImplementedError(\"load weights ...\")\n", + "\n", + " def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):\n", + " tm = self.thunder_module\n", + " from thunder.core.trace import tracectx\n", + "\n", + " checks = get_checks(prologue_trace)\n", + "\n", + " compute_producers, compute_consumers = utils.producers_and_consumers(computation_trace)\n", + "\n", + " proglogue_to_compute_outputs = prologue_trace.output[0]\n", + "\n", + " output_idxes = {id(o): i for i, o in enumerate(proglogue_to_compute_outputs)}\n", + "\n", + " computation_trace.push_scope([])\n", + " quantized_proxies: dict[int, str] = {} # id -> name\n", + "\n", + " for n, qs in self.quant_states.items():\n", + " param = tm.get_parameter(n)\n", + " check, get_param = checks[n]\n", + " quantized_proxies[id(get_param.output)] = n\n", + " # check has args: tensor, shape, device, dtype, requires_grad\n", + " proxy, _, _, _, requires_grad = check.args\n", + " thunder_device = thunder.devices.to_device(param.device)\n", + " thunder_device_str = thunder_device.device_str()\n", + " check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)\n", + " for n, param in tm.named_parameters():\n", + " if n not in self.quant_states:\n", + " check, get_param = checks[n]\n", + " proxy, _, _, _, requires_grad = check.args\n", + " thunder_device = thunder.devices.to_device(param.device)\n", + " thunder_device_str = thunder_device.device_str() \n", + " check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)\n", + "\n", + " new_computation_trace = thunder.core.trace.from_trace(computation_trace)\n", + "\n", + " proxies_to_replace = {}\n", + " for bsym in computation_trace.bound_symbols:\n", + " if bsym.sym == thunder.torch.linear and id(bsym.args[1]) in quantized_proxies:\n", + " assert len(bsym.args) == 3 # torch.linear(input, weight, bias)\n", + " assert bsym.args[2] is None\n", + " n = quantized_proxies[id(bsym.args[1])]\n", + " qs = self.quant_states[n]\n", + " # signature of the new symbol:\n", + " # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)\n", + " new_args = (\n", + " *bsym.args[:2],\n", + " qs[\"typ\"].value, # integer value\n", + " qs[\"shape\"],\n", + " )\n", + " mm_bsym = bsym.from_bsym(\n", + " sym=ggmlquant_matmul,\n", + " subsymbols=[],\n", + " args=new_args,\n", + " )\n", + "\n", + " new_computation_trace.bound_symbols.append(mm_bsym)\n", + " # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n", + " # preserve the \"meta\"-info like source location, header, etc.\n", + " # TODO: switch to a better solution when it is there\n", + " ggmlquant_matmul._bind_postprocess(mm_bsym)\n", + " elif bsym.sym == thunder.torch.embedding and id(bsym.args[1]) in quantized_proxies:\n", + " assert len(bsym.args) == 7 # torch.linear(input, weight, bias)\n", + " assert bsym.args[2] is None and bsym.args[3] is None\n", + " assert bsym.args[5] is False and bsym.args[6] is False\n", + " n = quantized_proxies[id(bsym.args[1])]\n", + " qs = self.quant_states[n]\n", + " new_args = (\n", + " *bsym.args[:2],\n", + " qs[\"typ\"].value, # integer value\n", + " qs[\"shape\"],\n", + " )\n", + " emb_bsym = bsym.from_bsym(\n", + " sym=ggmlquant_embed,\n", + " subsymbols=[],\n", + " args=new_args,\n", + " )\n", + "\n", + " new_computation_trace.bound_symbols.append(emb_bsym)\n", + " # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n", + " # preserve the \"meta\"-info like source location, header, etc.\n", + " # TODO: switch to a better solution when it is there\n", + " ggmlquant_embed._bind_postprocess(emb_bsym)\n", + " else:\n", + " new_computation_trace.bound_symbols.append(bsym.from_bsym())\n", + "\n", + " new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance(\"quant pass\"))\n", + " return prologue_trace, new_computation_trace, epilogue_trace\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c08de032-8db4-49ce-9382-868263d00d4d", + "metadata": {}, + "outputs": [], + "source": [ + "import thunder, torch" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9c070756-4857-4741-9658-ee83b9dceeaf", + "metadata": {}, + "outputs": [], + "source": [ + "import thunder.tests.litgpt_model" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "250b5edd-31d0-456c-8bf1-96fc73e4e5af", + "metadata": {}, + "outputs": [], + "source": [ + "import litgpt\n", + "with torch.device(\"meta\"):\n", + " m = thunder.tests.litgpt_model.GPT.from_name('Llama-3-8B-Instruct')\n", + " m.requires_grad_(False)\n", + " #del m.transformer.h[2:]\n", + "# enable the kv cache\n", + "device = \"cuda\"\n", + "with torch.device(device):\n", + " m.max_seq_length = max_returned_tokens\n", + " m.set_kv_cache(batch_size=1)\n", + "m.cos, m.sin = m.rope_cache(device=torch.device('cuda'))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "39264594-703f-4393-ac2d-6c88e50b5764", + "metadata": {}, + "outputs": [], + "source": [ + "model_file_name = '~/.ollama/models/manifests/registry.ollama.ai/library/llama3/latest'\n", + "\n", + "quant_transform = GGMLQuantTransform(model_file_name, torch.device(\"cuda\"))\n", + "tm = thunder.jit(m, early_transforms=[quant_transform])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a87b6c62-afa7-4537-90ac-de99b2a7c4af", + "metadata": {}, + "outputs": [], + "source": [ + "#a = torch.randint(1, 100, (1, 64), device=\"cuda\")\n", + "#tm(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a7a0c755-77a3-44a4-8f11-0b3d308c75f2", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "model = tm\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2d8666d-387f-4676-a01b-04c8567a9c5a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 1234\n" + ] + } + ], + "source": [ + "for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", + "num_samples = 1\n", + "with torch.inference_mode():\n", + " L.seed_everything(1234)\n", + " for i in range(num_samples):\n", + " t0 = time.perf_counter()\n", + " y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)\n", + " t = time.perf_counter() - t0\n", + " for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", + " fabric.print(tokenizer.decode(y))\n", + " tokens_generated = y.size(0) - prompt_length\n", + " fabric.print(\n", + " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", + " )\n", + " if fabric.device.type == \"cuda\":\n", + " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92cdfcae-4fc1-434e-acfc-27f9d5c53942", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65ad6791-a375-4cbe-b74b-2bf08b202d0e", + "metadata": {}, + "outputs": [], + "source": [ + "model.transformer.h[0].attn.kv_cache.v.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "278af111-7b12-4579-8efb-f8a03582c355", + "metadata": {}, + "outputs": [], + "source": [ + "model._forward_module.max_seq_length" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b33c1b1c-45b4-42d6-b363-9adcbeb8b017", + "metadata": {}, + "outputs": [], + "source": [ + "model._forward_module.config.n_head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96c7e96a-8794-4219-bd9f-f4dc0e18fde0", + "metadata": {}, + "outputs": [], + "source": [ + "model._model.mask_cache" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86e8c0c4-48cc-44c0-ad21-77a870f14dcb", + "metadata": {}, + "outputs": [], + "source": [ + "m2 = thunder.tests.litgpt_model.OverridenKVCache((1, 32, 286, 128), (1, 32, 286, 128), device=torch.device(\"cuda\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cd6e24a-3c16-4d48-94e3-c1da897ea04f", + "metadata": {}, + "outputs": [], + "source": [ + "tm2 = thunder.jit(m2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1a4cd5f-a4d2-458d-b51c-7cb679049655", + "metadata": {}, + "outputs": [], + "source": [ + "input_pos = torch.tensor([2], device=\"cuda\")\n", + "k, v = torch.randn(2, 1, 32, 1, 128, device=\"cuda\")\n", + "tm2(input_pos, k, v)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9a80452-8852-45ea-bad9-04048336d0ef", + "metadata": {}, + "outputs": [], + "source": [ + "m2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a90b8686-678b-4982-abc9-6bff19abfc45", + "metadata": {}, + "outputs": [], + "source": [ + "torch.Tensor.index_add" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97917740-cd76-4471-b1fd-284c3b159c65", + "metadata": {}, + "outputs": [], + "source": [ + "?? m.transformer.h[0].attn.kv_cache.forward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3daf3945-91e3-4836-806f-ef8dd972515f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_code_all_hidden": false, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 51fd803c40..dce0c3b9f5 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -3636,7 +3636,7 @@ def linear_meta(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> Ten utils.check(isinstance(w, TensorProxy), lambda: f"w={w} was not a TensorProxy!") # Checks that required arguments are on the same device - utils.check(a.device == w.device, lambda: f"Expected a.device={a.device} and w.device={w.device} to be the same!") + utils.check_same_device(a, w) # Acquires the computation dtype and checks that a and w have the same dtype dtype = a.dtype @@ -3696,7 +3696,7 @@ def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy: if a.ndim < 1 or b.ndim < 1: raise NotImplementedError - utils.check(a.device == b.device, lambda: f"Expected a.device={a.device} and b.device={b.device} to be the same") + utils.check_same_device(a, b) utils.check( dtypes.are_same_dtypes(a, b), lambda: f"Expected a.dtype={a.dtype} and b.dtype={b.dtype} to be the same" diff --git a/thunder/core/utils.py b/thunder/core/utils.py index a094fc6548..271dcdf301 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -666,7 +666,8 @@ def _reify(x): # TODO: improve device handling by canonicalizing devices and expressing them per langctx # TODO: should the comparison between devices be ==? def check_same_device(*args): - devices = tuple(x.device for x in args if isinstance(x, TensorProxyInterface)) + devices = tuple(x.device for x in args if isinstance(x, TensorProxyInterface) and x.device.type != "meta") + if len(devices) > 1: device = devices[0] for otherdevice in devices[1:]: diff --git a/thunder/tests/litgpt_model.py b/thunder/tests/litgpt_model.py index 0926715952..6dbd35f2ef 100644 --- a/thunder/tests/litgpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -118,6 +118,43 @@ name_to_config = {config["name"]: config for config in configs} +class OverridenKVCache(nn.Module): + def __init__( + self, + k_shape: tuple[int, int, int, int], + v_shape: tuple[int, int, int, int], + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> None: + super().__init__() + self.register_buffer("k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False) + self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) + + def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + # move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(k.dtype) + self.v = self.v.to(v.dtype) + # update the cache + # NOTE: `torch._dynamo.is_compiling` is being deprecated, we should update this once all versions have `torch.compiler.is_compiling`. + is_compiling = ( + torch.compiler.is_compiling if hasattr(torch.compiler, "is_compiling") else torch._dynamo.is_compiling + ) + if is_compiling(): + # inductor doesn't support `index_add` with bfloat16 + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) + return k, v + # See issue: "Support more indexing operators (index_copy and index_add)" + k = self.k = self.k.index_copy_(2, input_pos, k) + v = self.v = self.v.index_copy_(2, input_pos, v) + # THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro) + return k, v + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.k) + torch.nn.init.zeros_(self.v) + + import litgpt # add the testing configurations