From 6d6ad92d5046bf3ba0c2e4c8524869b48e6fd4ca Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 11 Jul 2024 07:10:24 +0200 Subject: [PATCH 1/4] ggml quantization in thunder v0 --- examples/ggml-quant/ggmltensor.py | 414 +++++++++++++ examples/ggml-quant/thunder_ggmlquant.ipynb | 641 ++++++++++++++++++++ thunder/core/utils.py | 3 +- 3 files changed, 1057 insertions(+), 1 deletion(-) create mode 100644 examples/ggml-quant/ggmltensor.py create mode 100644 examples/ggml-quant/thunder_ggmlquant.ipynb diff --git a/examples/ggml-quant/ggmltensor.py b/examples/ggml-quant/ggmltensor.py new file mode 100644 index 0000000000..96b7bcdcf9 --- /dev/null +++ b/examples/ggml-quant/ggmltensor.py @@ -0,0 +1,414 @@ +import collections +from enum import Enum +import functools +import json +import operator +import pathlib +import struct + +import numpy +import torch + + +class GgufMetadataValueType(Enum): # // gguf_metadata_value_type (uint32_t enum) in llama.cpp + uint8 = 0 + int8 = 1 + uint16 = 2 + int16 = 3 + uint32 = 4 + int32 = 5 + float32 = 6 + bool = 7 + string = 8 + array = 9 + uint64 = 10 + int64 = 11 + float64 = 12 + + +class GgmlType(Enum): # uint32 + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + # 4, 5 are Q4_2/3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + IQ2_XXS = 16 + IQ2_XS = 17 + IQ3_XXS = 18 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = (22,) + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + + +GGML_BLOCK_SIZES = { + GgmlType.F32: 4, + GgmlType.Q4_0: 2 + 16, + GgmlType.Q8_0: 2 + 32, + GgmlType.Q2_K: 256 // 16 + 256 // 4 + 2 + 2, + GgmlType.Q3_K: 256 // 8 + 256 // 4 + 12 + 2, + GgmlType.Q4_K: 2 + 2 + 12 + 256 // 2, + GgmlType.Q5_K: 2 + 2 + 12 + 256 // 8 + 256 // 2, + GgmlType.Q6_K: 256 // 2 + 256 // 4 + 256 // 16 + 2, +} + +GGML_ELEMENTS_PER_BLOCK = { + GgmlType.F32: 1, + GgmlType.Q4_0: 32, + GgmlType.Q8_0: 32, + GgmlType.Q2_K: 256, + GgmlType.Q3_K: 256, + GgmlType.Q4_K: 256, + GgmlType.Q5_K: 256, + GgmlType.Q6_K: 256, +} + + +def compute_number_of_blocks(shape, ggml_type): + # todo: from the code, it looks like the padding might be per-row + numel = functools.reduce(operator.mul, shape, 1) + block_numel = GGML_ELEMENTS_PER_BLOCK[ggml_type] + num_blocks = (numel + block_numel - 1) // block_numel + return num_blocks + + +def read_gguf_string(f): + (length,) = struct.unpack("L", f.read(8)) + res = f.read(length).decode("utf8") + return res + + +def read_metadata_kv(f): + k = read_gguf_string(f) + (metadata_value_type,) = struct.unpack("I", f.read(4)) + metadata_value_type = GgufMetadataValueType(metadata_value_type) + + def read_value(f, typ): + if typ == GgufMetadataValueType.string: + v = read_gguf_string(f) + elif typ == GgufMetadataValueType.uint32: + (v,) = struct.unpack("I", f.read(4)) + v = numpy.uint32(v) + elif typ == GgufMetadataValueType.int32: + (v,) = struct.unpack("i", f.read(4)) + v = numpy.int32(v) + elif typ == GgufMetadataValueType.float32: + (v,) = struct.unpack("f", f.read(4)) + v = numpy.float32(v) + elif typ == GgufMetadataValueType.array: + (element_metadata_value_type,) = struct.unpack("> 2) & 0x3) << 4)).view(torch.int8) - 32 + q3 = ((block_q6_K_ql[:, :, 0, :] >> 4) | (((block_q6_K_qh[:, :, :] >> 4) & 0x3) << 4)).view(torch.int8) - 32 + q4 = ((block_q6_K_ql[:, :, 1, :] >> 4) | (((block_q6_K_qh[:, :, :] >> 6) & 0x3) << 4)).view(torch.int8) - 32 + y[:, :, 0, :] = block_q6_K_d * block_q6_K_scales[:, :, 0, :] * q1 + y[:, :, 1, :] = block_q6_K_d * 
block_q6_K_scales[:, :, 1, :] * q2 + y[:, :, 2, :] = block_q6_K_d * block_q6_K_scales[:, :, 2, :] * q3 + y[:, :, 3, :] = block_q6_K_d * block_q6_K_scales[:, :, 3, :] * q4 + numel = functools.reduce(operator.mul, shape, 1) + # in ggml the first is the fastest moving dimension + y = y.reshape(-1)[:numel].reshape(*shape[::-1]).permute(*list(range(len(shape) - 1, -1, -1))) + return y + + +# define QK4_0 32 +# typedef struct { +# ggml_half d; // delta +# uint8_t qs[QK4_0 / 2]; // nibbles / quants +# } block_q4_0; +@register_ggml_dequantizer(GgmlType.Q4_0) +def dequantize_q4_0(block_data, shape, dtype=torch.float32): + num_blocks, num_block_bytes = block_data.shape + assert num_block_bytes == GGML_BLOCK_SIZES[GgmlType.Q4_0] + block_q4_0_d = block_data[:, :2].view(torch.float16).to(torch.float32).reshape(num_blocks, 1) + block_q4_0_qs = block_data[:, 2:] + num_blocks, _ = block_q4_0_qs.shape + y = torch.empty((num_blocks, 2, 16), dtype=dtype, device=block_data.device) + + q0 = block_q4_0_qs & 0xF + q0 = q0.view(torch.int8) - 8 + q1 = (block_q4_0_qs >> 4).view(torch.int8) - 8 + y[:, 0, :] = block_q4_0_d * q0 + y[:, 1, :] = block_q4_0_d * q1 + numel = functools.reduce(operator.mul, shape, 1) + # in ggml the first is the fastest moving dimension + y = y.reshape(-1)[:numel].reshape(*shape[::-1]).permute(*list(range(len(shape) - 1, -1, -1))) + return y + + +@register_ggml_dequantizer(GgmlType.F32) +@register_ggml_dequantizer(GgmlType.F16) +def dequantize_noop(block_data, shape, dtype=torch.float32): + assert block_data.shape == shape + return block_data.to(dtype) + + +def dequantize(qw, typ, shape, dtype=torch.float32): + dequantizer = GGML_DEQUANTIZERS.get(typ) + if typ is None: + raise NotImplementedError("Cannot decode {typ}") + return dequantizer(qw, shape, dtype) + + +def cat_quantized(tensors, infos, dim): + assert isinstance(tensors, collections.abc.Sequence) and len(tensors) > 0 + assert isinstance(infos, collections.abc.Sequence) and len(infos) == len(tensors) + typ_0, shape_0 = infos[0] + assert typ_0 not in (GgmlType.F16, GgmlType.F32), "This only works for quantized" + total_dim = len(shape_0) + if dim < 0: + dim += total_dim + assert 0 <= dim < total_dim + concat_dim_size = 0 + reshaped_tensors = [] + for (typ_i, shape_i), tensor_i in zip(infos, tensors): + assert typ_i == typ_0 and len(shape_i) == total_dim, "concatenated tensors must have same type and dimension" + assert ( + shape_i[:dim] == shape_0[:dim] and shape_i[dim + 1 :] == shape_0[dim + 1 :] + ), "shapes must match except in the concat dimension" + concat_dim_size += shape_i[dim] + numel_in_back = functools.reduce(operator.mul, shape_i[: dim + 1]) + assert numel_in_back % GGML_ELEMENTS_PER_BLOCK[typ_0] == 0 + blocks_in_back = numel_in_back // GGML_ELEMENTS_PER_BLOCK[typ_0] + reshaped_tensors.append(tensor_i.reshape(-1, blocks_in_back, GGML_BLOCK_SIZES[typ_0])) + new_tensor = torch.cat(reshaped_tensors, dim=1).view(-1, GGML_BLOCK_SIZES[typ_0]) + new_info = (typ_0, (*shape_0[:dim], concat_dim_size, *shape_0[dim + 1 :])) + return new_tensor, new_info + + +def merge_attention_weights(qq, info_q, qk, info_k, qv, info_v): + typ_q, shape_q = info_q + typ_k, shape_k = info_k + typ_v, shape_v = info_v + + num_blocks = qq.shape[0] + assert qq.shape[1] == GGML_BLOCK_SIZES[typ_q] + assert shape_q[0] % GGML_ELEMENTS_PER_BLOCK[typ_q] == 0 + qq_swizzled = ( + qq.view(-1, 64, 2, shape_q[0] // GGML_ELEMENTS_PER_BLOCK[typ_q], qq.shape[1]).transpose(1, 2).reshape(*qq.shape) + ) + qk_swizzled = ( + qk.view(-1, 64, 2, shape_k[0] // 
GGML_ELEMENTS_PER_BLOCK[typ_k], qk.shape[1]).transpose(1, 2).reshape(*qk.shape) + ) + + dqq2 = dequantize(qq_swizzled, typ_q, shape_q) + dqk2 = dequantize(qk_swizzled, typ_k, shape_k) + + assert shape_k[0] % GGML_ELEMENTS_PER_BLOCK[typ_k] == 0 + assert shape_v[0] % GGML_ELEMENTS_PER_BLOCK[typ_v] == 0 + + q_all, (typ_all, shape_all) = cat_quantized( + [qq_swizzled, qk_swizzled, qv], + [ + (typ_q, (shape_q[0], shape_q[1] // 8, 8)), + (typ_k, (shape_k[0], shape_k[1] // 8, 8)), + (typ_v, (shape_v[0], shape_v[1] // 8, 8)), + ], + dim=1, + ) + shape_all = (shape_all[0], shape_all[1] * shape_all[2]) + return q_all, (typ_all, shape_all) + + +class GgmlDataReader: + def __init__(self, model_file_name): + model_file_path = pathlib.Path(model_file_name).expanduser() + data = json.load(open(model_file_path)) + (model_data,) = [d for d in data["layers"] if d["mediaType"] == "application/vnd.ollama.image.model"] + model_data_blob = model_data["digest"].replace(":", "-") + + p = model_file_path + while p.name != ".ollama": + p = p.parent + model_blob_path = p / "models" / "blobs" / model_data_blob + + self.f = open(model_blob_path, "rb") + metadata, tensor_infos = read_gguf_header(self.f) + + self.tensor_infos = {name: (typ, shape, offset) for name, shape, typ, offset in tensor_infos} + + def read_tensor(self, ggml_name): + typ, shape, offset = self.tensor_infos[ggml_name] + + self.f.seek(offset) + num_block_bytes = GGML_BLOCK_SIZES[typ] + num_blocks = compute_number_of_blocks(shape, typ) + numel = functools.reduce(operator.mul, shape, 1) + + block_data_raw = self.f.read(num_blocks * num_block_bytes) + block_data = numpy.frombuffer(block_data_raw, dtype=numpy.uint8).reshape(num_blocks, num_block_bytes) + if typ == GgmlType.F32: + # in ggml the first is the fastest moving dimension + block_data = ( + block_data.view(numpy.float32) + .reshape(-1)[:numel] + .reshape(*shape[::-1]) + .transpose(*list(range(len(shape) - 1, -1, -1))) + ) + elif typ == GgmlType.F16: + block_data = ( + block_data.view(numpy.float16) + .reshape(-1)[:numel] + .reshape(*shape[::-1]) + .transpose(*list(range(len(shape) - 1, -1, -1))) + ) + import warnings + + with warnings.catch_warnings(action="ignore", category=UserWarning): # ignore read-only warning + return torch.from_numpy(block_data), (typ, shape) + + def close(self): + self.f.close() + + def get_parameter(self, name): + ggml_name = ( + name.replace("transformer.wte.weight", "token_embd.weight") + .replace("transformer.h.", "blk.") + .replace(".norm_1.weight", ".attn_norm.weight") + .replace(".attn.attn.weight", ".attn_q.weight") + .replace(".attn.proj.weight", ".attn_output.weight") + .replace(".norm_2.weight", ".ffn_norm.weight") + .replace(".mlp.fc_1.weight", ".ffn_gate.weight") + .replace(".mlp.fc_2.weight", ".ffn_up.weight") + .replace(".mlp.proj.weight", ".ffn_down.weight") + .replace("transformer.ln_f.weight", "output_norm.weight") + .replace("lm_head.weight", "output.weight") + ) + if not name.endswith("attn.attn.weight"): + q, (typ, shape) = self.read_tensor(ggml_name) + else: + ggml_name_key = ggml_name.replace("attn_q", "attn_k") + ggml_name_value = ggml_name.replace("attn_q", "attn_v") + q, (typ, shape) = merge_attention_weights( + *self.read_tensor(ggml_name), *self.read_tensor(ggml_name_key), *self.read_tensor(ggml_name_value) + ) + return q, (typ, shape) diff --git a/examples/ggml-quant/thunder_ggmlquant.ipynb b/examples/ggml-quant/thunder_ggmlquant.ipynb new file mode 100644 index 0000000000..52964e37ea --- /dev/null +++ 
b/examples/ggml-quant/thunder_ggmlquant.ipynb @@ -0,0 +1,641 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e2db72f9", + "metadata": { + "hideCode": false, + "hidePrompt": false + }, + "source": [ + "# Loading GGML / Ollama weights into LitGPT" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d3b8246e", + "metadata": { + "hideCode": false, + "hidePrompt": false + }, + "outputs": [], + "source": [ + "import thunder\n", + "import torch\n", + "import ggmltensor\n", + "\n", + "def load_ggml_weights(model, fn): \n", + " ggml_quant = ggmltensor.GgmlDataReader(fn)\n", + "\n", + " for n, p in model.named_parameters(): \n", + " qw, (typ, shape) = ggml_quant.get_parameter(n)\n", + " with torch.no_grad():\n", + " w = ggmltensor.dequantize(qw, typ, shape, dtype=p.dtype).to(p.device)\n", + " p.copy_(w.t())\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5a7a129a-c646-4adf-8efe-98f217a4d786", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading model '/home/tv/data/firma/grid/thunder/litgpt/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/lit_model.pth' with {'name': 'Llama-3-8B-Instruct', 'hf_config': {'name': 'Meta-Llama-3-8B-Instruct', 'org': 'meta-llama'}, 'scale_embeddings': False, 'block_size': 8192, 'vocab_size': 128000, 'padding_multiple': 512, 'padded_vocab_size': 128256, 'n_layer': 32, 'n_head': 32, 'head_size': 128, 'n_embd': 4096, 'rotary_percentage': 1.0, 'parallel_residual': False, 'bias': False, 'lm_head_bias': False, 'n_query_groups': 8, 'shared_attention_norm': False, 'norm_class_name': 'RMSNorm', 'norm_eps': 1e-05, 'mlp_class_name': 'LLaMAMLP', 'gelu_approximate': 'none', 'intermediate_size': 14336, 'rope_condense_ratio': 1, 'rope_base': 500000, 'n_expert': 0, 'n_expert_per_token': 0, 'rope_n_elem': 128}\n", + "Time to instantiate model: 0.08 seconds.\n", + "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "Time to load the model weights: 14.32 seconds.\n" + ] + } + ], + "source": [ + "# Copyright Lightning AI. 
Licensed under the Apache License 2.0, see LICENSE file.\n", + "\n", + "import sys\n", + "import time\n", + "from pathlib import Path\n", + "from typing import Any, Literal, Optional\n", + "\n", + "import lightning as L\n", + "import torch\n", + "import torch._dynamo.config\n", + "import torch._inductor.config\n", + "#from lightning.fabric.plugins import BitsandbytesPrecision\n", + "\n", + "from litgpt import GPT, Config, PromptStyle, Tokenizer\n", + "from litgpt.prompts import has_prompt_style, load_prompt_style\n", + "from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint\n", + "\n", + "\n", + "\n", + "def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:\n", + " if torch._dynamo.is_compiling():\n", + " # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly\n", + " distribution = torch.empty_like(probs).exponential_(1)\n", + " return torch.argmax(probs / distribution, dim=-1, keepdim=True)\n", + " return torch.multinomial(probs, num_samples=1)\n", + "\n", + "\n", + "def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:\n", + " sorted_logits, sorted_indices = torch.sort(logits, descending=False)\n", + " cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)\n", + " # Example:\n", + " # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]\n", + " # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7\n", + " sorted_indices_to_remove = cumulative_probs <= (1 - top_p)\n", + " # Keep at least 1 token always to prevent the case where no token is selected\n", + " # In this case the most probable one is always kept\n", + " sorted_indices_to_remove[-1:] = 0\n", + " indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)\n", + " logits = logits.masked_fill(indices_to_remove, float(\"-inf\"))\n", + " return logits\n", + "\n", + "\n", + "def sample(\n", + " logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None, top_p: float = 1.0\n", + ") -> torch.Tensor:\n", + " if top_p < 0.0 or top_p > 1.0:\n", + " raise ValueError(f\"top_p must be in [0, 1], got {top_p}\")\n", + " logits = logits[0, -1]\n", + " # optionally crop the logits to only the top k options\n", + " if top_k is not None:\n", + " v, i = torch.topk(logits, min(top_k, logits.size(-1)))\n", + " # do not use `torch.where` as in nanogpt because it will repeat top-k collisions\n", + " logits = torch.full_like(logits, float(\"-inf\")).scatter_(-1, i, v)\n", + " # optionally scale the logits and sample from a probability distribution\n", + " if temperature > 0.0 or top_p > 0.0:\n", + " if temperature > 0.0:\n", + " logits = logits / temperature\n", + " # optionally crop the logits to smallest set of logits with a cumulative probability above top_p\n", + " if top_p < 1.0:\n", + " logits = sample_top_p(logits, top_p)\n", + " probs = torch.nn.functional.softmax(logits, dim=-1)\n", + " return multinomial_num_samples_1(probs)\n", + " return torch.argmax(logits, dim=-1, keepdim=True)\n", + "\n", + "\n", + "def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:\n", + " logits = model(x, input_pos)\n", + " next = sample(logits, **kwargs)\n", + " return next.to(dtype=x.dtype)\n", + "\n", + "\n", + "@torch.inference_mode()\n", + "def generate(\n", + " model: GPT,\n", + " prompt: torch.Tensor,\n", + " max_returned_tokens: int,\n", + " *,\n", + " temperature: float = 1.0,\n", + " 
top_k: Optional[int] = None,\n", + " top_p: float = 1.0,\n", + " eos_id: Optional[int] = None,\n", + ") -> torch.Tensor:\n", + " \"\"\"Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.\n", + "\n", + " The implementation of this function is modified from A. Karpathy's nanoGPT.\n", + "\n", + " Args:\n", + " model: The model to use.\n", + " prompt: Tensor of shape (T) with indices of the prompt sequence.\n", + " max_returned_tokens: The maximum number of tokens to return (given plus generated).\n", + " temperature: Scales the predicted logits by 1 / temperature.\n", + " top_k: If specified, only sample among the tokens with the k highest probabilities.\n", + " top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n", + " In top-p sampling, the next token is sampled from the highest probability tokens\n", + " whose cumulative probability exceeds the threshold `top_p`. When specified,\n", + " it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n", + " to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n", + " It can be used in conjunction with `top_k` and `temperature` with the following order\n", + " of application:\n", + "\n", + " 1. `top_k` sampling\n", + " 2. `temperature` scaling\n", + " 3. `top_p` sampling\n", + "\n", + " For more details, see https://arxiv.org/abs/1904.09751\n", + " or https://huyenchip.com/2024/01/16/sampling.html#top_p\n", + " eos_id: If specified, stop generating any more token once the token is triggered.\n", + " \"\"\"\n", + " T = prompt.size(0)\n", + " assert max_returned_tokens > T\n", + " if model.max_seq_length < max_returned_tokens - 1:\n", + " # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a\n", + " # data dependency on the `input_pos` tensor and impact model compilation. 
Since this setting is uncommon, we do\n", + " # not support it to avoid negatively impacting the overall speed\n", + " raise NotImplementedError(f\"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}\")\n", + "\n", + " device = prompt.device\n", + " tokens = [prompt]\n", + " input_pos = torch.tensor([T], device=device)\n", + " token = next_token(\n", + " model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p\n", + " ).clone()\n", + " tokens.append(token)\n", + " for _ in range(2, max_returned_tokens - T + 1):\n", + " token = next_token(\n", + " model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p\n", + " ).clone()\n", + " tokens.append(token)\n", + " if token == eos_id:\n", + " break\n", + " input_pos = input_pos.add_(1)\n", + " return torch.cat(tokens)\n", + "\n", + "\n", + "with torch.inference_mode():\n", + " prompt: str = \"What food do llamas eat?\"\n", + " num_samples: int = 1\n", + " max_new_tokens: int = 256\n", + " top_k: Optional[int] = 50\n", + " top_p: float = 1.0\n", + " temperature: float = 0.8\n", + " checkpoint_dir: Path = Path(\"/home/tv/data/firma/grid/thunder/litgpt/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/\")\n", + " quantize: Optional[Literal[\"bnb.nf4\", \"bnb.nf4-dq\", \"bnb.fp4\", \"bnb.fp4-dq\", \"bnb.int8\"]] = \"bnb.nf4\"\n", + " precision: Optional[str] = \"bf16-true\"\n", + " compile: bool = False\n", + "# litgpt generate base --quantize bnb.nf4 --checkpoint_dir checkpoints/tiiuae/falcon-7b --precision bf16-true --max_new_tokens 256\n", + "\n", + " \"\"\"Generates text samples based on a pre-trained model and tokenizer.\n", + "\n", + " Args:\n", + " prompt: The prompt string to use for generating the samples.\n", + " num_samples: The number of text samples to generate.\n", + " max_new_tokens: The number of generation steps to take.\n", + " top_k: The number of top most probable tokens to consider in the sampling process.\n", + " top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.\n", + " In top-p sampling, the next token is sampled from the highest probability tokens\n", + " whose cumulative probability exceeds the threshold `top_p`. When specified,\n", + " it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent\n", + " to sampling the most probable token, while `top_p=1` samples from the whole distribution.\n", + " It can be used in conjunction with `top_k` and `temperature` with the following order\n", + " of application:\n", + "\n", + " 1. `top_k` sampling\n", + " 2. `temperature` scaling\n", + " 3. `top_p` sampling\n", + "\n", + " For more details, see https://arxiv.org/abs/1904.09751\n", + " or https://huyenchip.com/2024/01/16/sampling.html#top_p\n", + " temperature: A value controlling the randomness of the sampling process. 
Higher values result in more random\n", + " samples.\n", + " checkpoint_dir: The checkpoint directory to load.\n", + " quantize: Whether to quantize the model and using which method:\n", + " - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes\n", + " - bnb.int8: 8-bit quantization from bitsandbytes\n", + " for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md\n", + " precision: Indicates the Fabric precision setting to use.\n", + " compile: Whether to compile the model.\n", + " \"\"\"\n", + " precision = precision or get_default_supported_precision(training=False)\n", + "\n", + " #plugins = BitsandbytesPrecision(mode='nf4', dtype=torch.bfloat16)\n", + " \n", + " precision = 'bf16-true'\n", + "\n", + " fabric = L.Fabric(devices=1, precision=precision) #, plugins=plugins)\n", + "\n", + " check_valid_checkpoint_dir(checkpoint_dir)\n", + " config = Config.from_file(checkpoint_dir / \"model_config.yaml\")\n", + "\n", + " checkpoint_path = checkpoint_dir / \"lit_model.pth\"\n", + "\n", + " tokenizer = Tokenizer(checkpoint_dir)\n", + " prompt_style = (\n", + " load_prompt_style(checkpoint_dir) if has_prompt_style(checkpoint_dir) else PromptStyle.from_config(config)\n", + " )\n", + "\n", + " prompt = prompt_style.apply(prompt)\n", + " encoded = tokenizer.encode(prompt, device=fabric.device)\n", + " prompt_length = encoded.size(0)\n", + " max_returned_tokens = prompt_length + max_new_tokens\n", + "\n", + " fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n", + " t0 = time.perf_counter()\n", + " with fabric.init_module(empty_init=True):\n", + " model = GPT(config)\n", + " fabric.print(f\"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n", + " with fabric.init_tensor():\n", + " # set the max_seq_length to limit the memory usage to what we need\n", + " model.max_seq_length = max_returned_tokens\n", + " # enable the kv cache\n", + " model.set_kv_cache(batch_size=1)\n", + " model.eval()\n", + "\n", + " if compile:\n", + " torch._dynamo.config.automatic_dynamic_shapes = True\n", + " torch._inductor.config.triton.unique_kernel_names = True\n", + " torch._inductor.config.coordinate_descent_tuning = True\n", + " global next_token\n", + " next_token = torch.compile(next_token, mode=\"reduce-overhead\")\n", + "\n", + " model = fabric.setup_module(model)\n", + " \n", + " ggml_fn = '~/.ollama/models/manifests/registry.ollama.ai/library/llama3/latest'\n", + "\n", + " t0 = time.perf_counter()\n", + " #load_checkpoint(fabric, model, checkpoint_path)\n", + " load_ggml_weights(model._original_module, ggml_fn)\n", + "\n", + " fabric.print(f\"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.\", file=sys.stderr)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "483cc7da-4db8-4953-b5c7-657b709cc8ce", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 1234\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "system\n", + "\n", + "You are a helpful assistant.\n", + "user\n", + "\n", + "What food do llamas eat?assistant\n", + "\n", + "Llamas are ruminant animals, which means they have a four-chambered stomach and primarily eat plant-based foods. Here are some of the main foods that llamas like to munch on:\n", + "\n", + "1. Grass: Llamas love to graze on grasses, including orchard grass, timothy grass, and Bermuda grass.\n", + "2. 
Hay: Hay is a staple in a llama's diet. They enjoy a variety of hay types, such as alfalfa, oat hay, and grass hay.\n", + "3. Grains: Llamas can also eat grains like oats, barley, and corn, but in moderation. Too much grain can lead to digestive issues.\n", + "4. Fruits and vegetables: Llamas enjoy treats like apples, carrots, sweet potatoes, and leafy greens like lettuce and spinach.\n", + "5. Minerals: Llamas require access to a mineral block or loose minerals, like calcium and phosphorus, to stay healthy.\n", + "6. Salt: Llamas need access to salt licks or a salt block to regulate their electrolyte levels.\n", + "7. Pellets: A high-fiber pellet specifically formulated for llamas can be a convenient and nutritious addition to their diet.\n", + "\n", + "Remember to always provide fresh water and a mineral block or loose minerals along\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Time for inference 1: 6.42 sec total, 39.88 tokens/sec\n", + "Memory used: 18.31 GB\n" + ] + } + ], + "source": [ + "with torch.inference_mode():\n", + " L.seed_everything(1234)\n", + " for i in range(num_samples):\n", + " t0 = time.perf_counter()\n", + " y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)\n", + " t = time.perf_counter() - t0\n", + " for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", + " fabric.print(tokenizer.decode(y))\n", + " tokens_generated = y.size(0) - prompt_length\n", + " fabric.print(\n", + " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", + " )\n", + " if fabric.device.type == \"cuda\":\n", + " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" + ] + }, + { + "cell_type": "markdown", + "id": "8705fb23-8b24-488f-ad85-72e4cdd98c25", + "metadata": {}, + "source": [ + "# Thunder transform" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a76d44e0-dd65-4e6c-a747-96f2ef1e33fb", + "metadata": {}, + "outputs": [], + "source": [ + "from collections.abc import Sequence\n", + "\n", + "import thunder\n", + "from thunder.core.transform_common import EarlyTransform\n", + "from thunder.core import utils\n", + "from thunder.core import prims\n", + "import torch\n", + "\n", + "from thunder.transforms.utils import (\n", + " get_orig_and_thunder_module_proxies_from_prologue,\n", + " get_checks,\n", + " add_trace_output,\n", + ")\n", + "\n", + "import ggmltensor\n", + "\n", + "ggmlquant_executor = thunder.extend.OperatorExecutor(\"quant_ggml\", version=0.1)\n", + "\n", + "def ggmlquant_matmul_meta(x, qweight, ggmltype: int, shape):\n", + " assert isinstance(shape, Sequence) and len(shape) == 2\n", + " assert x.shape[-1] == shape[1], f\"{x.shape=}, rhs {shape=}\"\n", + " return thunder.TensorProxy(like=x, shape=(*x.shape[:-1], shape[0]))\n", + "\n", + "\n", + "def ggmlquant_matmul_impl(x, qweight, ggmltype: int, shape):\n", + " w = ggmltensor.dequantize(qweight, ggmltensor.GgmlType(ggmltype), shape, dtype=x.dtype)\n", + " return x @ w\n", + "\n", + "def ggmlquant_embed_meta(x, qweight, ggmltype: int, shape):\n", + " assert isinstance(shape, Sequence) and len(shape) == 2\n", + " # checks for mul\n", + " return thunder.TensorProxy(like=x, shape=(*x.shape, shape[1]))\n", + "\n", + "def ggmlquant_embed_impl(x, qweight, ggmltype: int, shape):\n", + " w = ggmltensor.dequantize(qweight, ggmltensor.GgmlType(ggmltype), shape, 
dtype=torch.bfloat16)\n", + " return torch.nn.functional.embedding(x, w.t()) \n", + "\n", + "ggmlquant_matmul = ggmlquant_executor.register_operator(\n", + " \"ggmlquant_matmul\", meta=ggmlquant_matmul_meta, fn=ggmlquant_matmul_impl\n", + ")\n", + "\n", + "ggmlquant_embed = ggmlquant_executor.register_operator(\n", + " \"ggmlquant_embed\", meta=ggmlquant_embed_meta, fn=ggmlquant_embed_impl\n", + ")\n", + "\n", + "\n", + "\n", + "class GGMLQuantTransform(EarlyTransform):\n", + " def __init__(self, model_file_name, device):\n", + " self.quant_states = {}\n", + " self.quantized_submodule_names = set()\n", + " self.device = device\n", + " self.model_file_name = model_file_name\n", + "\n", + " def transform_module(self, model: thunder.ThunderModule):\n", + " ggml_quant = ggmltensor.GgmlDataReader(self.model_file_name)\n", + " self.thunder_module = model\n", + "\n", + " def convert_layer_with_weight(tm, name):\n", + " self.quantized_submodule_names.add(name)\n", + " weight_name = f\"{name}.weight\"\n", + " w = tm.get_parameter(weight_name)\n", + " qw, (typ, shape) = ggml_quant.get_parameter(weight_name)\n", + " tm._overrides_parameters[weight_name] = qw.to(self.device)\n", + " if not qw.is_floating_point():\n", + " self.quant_states[weight_name] = {\"typ\": typ, \"shape\": shape}\n", + "\n", + " for n, submodule in model._model.named_modules():\n", + " if hasattr(submodule, \"weight\"):\n", + " convert_layer_with_weight(model, n)\n", + " ggml_quant.close()\n", + "\n", + " def transform_state_dict_for_submodule(self, model: thunder.ThunderModule, submodule_name: str, state_dict: dict):\n", + " raise NotImplementedError(\"load weights ...\")\n", + "\n", + " def transform_traces(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):\n", + " tm = self.thunder_module\n", + " from thunder.core.trace import tracectx\n", + "\n", + " checks = get_checks(prologue_trace)\n", + "\n", + " compute_producers, compute_consumers = utils.producers_and_consumers(computation_trace)\n", + "\n", + " proglogue_to_compute_outputs = prologue_trace.output[0]\n", + "\n", + " output_idxes = {id(o): i for i, o in enumerate(proglogue_to_compute_outputs)}\n", + "\n", + " computation_trace.push_scope([])\n", + " quantized_proxies: dict[int, str] = {} # id -> name\n", + "\n", + " for n, qs in self.quant_states.items():\n", + " param = tm.get_parameter(n)\n", + " check, get_param = checks[n]\n", + " quantized_proxies[id(get_param.output)] = n\n", + " # check has args: tensor, shape, device, dtype, requires_grad\n", + " proxy, _, _, _, requires_grad = check.args\n", + " thunder_device = thunder.devices.to_device(param.device)\n", + " thunder_device_str = thunder_device.device_str()\n", + " check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)\n", + " for n, param in tm.named_parameters():\n", + " if n not in self.quant_states:\n", + " check, get_param = checks[n]\n", + " proxy, _, _, _, requires_grad = check.args\n", + " thunder_device = thunder.devices.to_device(param.device)\n", + " thunder_device_str = thunder_device.device_str() \n", + " check.args = (proxy, (*param.shape,), thunder_device_str, param.dtype, False)\n", + "\n", + " new_computation_trace = thunder.core.trace.from_trace(computation_trace)\n", + "\n", + " proxies_to_replace = {}\n", + " for bsym in computation_trace.bound_symbols:\n", + " if bsym.sym == thunder.torch.linear and id(bsym.args[1]) in quantized_proxies:\n", + " assert len(bsym.args) == 3 # torch.linear(input, weight, bias)\n", + " assert bsym.args[2] is None\n", + 
" n = quantized_proxies[id(bsym.args[1])]\n", + " qs = self.quant_states[n]\n", + " # signature of the new symbol:\n", + " # bnb_matmul_nf4(x, qweight, bias, absmax, quant_map, blocksize, dtype, shape)\n", + " new_args = (\n", + " *bsym.args[:2],\n", + " qs[\"typ\"].value, # integer value\n", + " qs[\"shape\"],\n", + " )\n", + " mm_bsym = bsym.from_bsym(\n", + " sym=ggmlquant_matmul,\n", + " subsymbols=[],\n", + " args=new_args,\n", + " )\n", + "\n", + " new_computation_trace.bound_symbols.append(mm_bsym)\n", + " # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n", + " # preserve the \"meta\"-info like source location, header, etc.\n", + " # TODO: switch to a better solution when it is there\n", + " ggmlquant_matmul._bind_postprocess(mm_bsym)\n", + " elif bsym.sym == thunder.torch.embedding and id(bsym.args[1]) in quantized_proxies:\n", + " assert len(bsym.args) == 7 # torch.linear(input, weight, bias)\n", + " assert bsym.args[2] is None and bsym.args[3] is None\n", + " assert bsym.args[5] is False and bsym.args[6] is False\n", + " n = quantized_proxies[id(bsym.args[1])]\n", + " qs = self.quant_states[n]\n", + " new_args = (\n", + " *bsym.args[:2],\n", + " qs[\"typ\"].value, # integer value\n", + " qs[\"shape\"],\n", + " )\n", + " emb_bsym = bsym.from_bsym(\n", + " sym=ggmlquant_embed,\n", + " subsymbols=[],\n", + " args=new_args,\n", + " )\n", + "\n", + " new_computation_trace.bound_symbols.append(emb_bsym)\n", + " # we need the postprocess to set the internal state (call_ctx) because we do not bind / execute the new symbol to\n", + " # preserve the \"meta\"-info like source location, header, etc.\n", + " # TODO: switch to a better solution when it is there\n", + " ggmlquant_embed._bind_postprocess(emb_bsym)\n", + " else:\n", + " new_computation_trace.bound_symbols.append(bsym.from_bsym())\n", + "\n", + " new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance(\"quant pass\"))\n", + " return prologue_trace, new_computation_trace, epilogue_trace\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c08de032-8db4-49ce-9382-868263d00d4d", + "metadata": {}, + "outputs": [], + "source": [ + "import thunder, torch" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9c070756-4857-4741-9658-ee83b9dceeaf", + "metadata": {}, + "outputs": [], + "source": [ + "import thunder.tests.litgpt_model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "250b5edd-31d0-456c-8bf1-96fc73e4e5af", + "metadata": {}, + "outputs": [], + "source": [ + "with torch.device(\"meta\"):\n", + " m = thunder.tests.litgpt_model.GPT.from_name('Llama-3-8B-Instruct')\n", + " m.requires_grad_(False)\n", + "m.cos, m.sin = m.rope_cache(device=torch.device('cuda'))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "39264594-703f-4393-ac2d-6c88e50b5764", + "metadata": {}, + "outputs": [], + "source": [ + "model_file_name = '~/.ollama/models/manifests/registry.ollama.ai/library/llama3/latest'\n", + "\n", + "quant_transform = GGMLQuantTransform(model_file_name, torch.device(\"cuda\"))\n", + "tm = thunder.jit(m, early_transforms=[quant_transform])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a87b6c62-afa7-4537-90ac-de99b2a7c4af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 5.6120e-03, -5.2586e-01, 4.9149e-01, ..., -1.3942e-01,\n", + " -1.3927e-01, -1.3927e-01],\n", + " [ 1.3441e+00, 1.6912e+00, 2.4314e+00, ..., 
-1.2766e+00,\n", + " -1.2766e+00, -1.2766e+00],\n", + " [ 5.4205e+00, 6.1258e+00, 6.0442e+00, ..., -3.8762e+00,\n", + " -3.8776e+00, -3.8776e+00],\n", + " ...,\n", + " [ 8.1336e+00, 8.4065e+00, 8.6552e+00, ..., -2.5787e+00,\n", + " -2.5798e+00, -2.5798e+00],\n", + " [ 8.2335e+00, 8.9660e+00, 8.8519e+00, ..., -2.9638e+00,\n", + " -2.9643e+00, -2.9643e+00],\n", + " [ 8.0281e+00, 8.1912e+00, 8.4859e+00, ..., -2.8124e+00,\n", + " -2.8130e+00, -2.8130e+00]]], device='cuda:0')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = torch.randint(1, 100, (1, 64), device=\"cuda\")\n", + "tm(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "945bae75-03bb-43db-9681-db0cc4b4e80b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_code_all_hidden": false, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/thunder/core/utils.py b/thunder/core/utils.py index a094fc6548..271dcdf301 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -666,7 +666,8 @@ def _reify(x): # TODO: improve device handling by canonicalizing devices and expressing them per langctx # TODO: should the comparison between devices be ==? def check_same_device(*args): - devices = tuple(x.device for x in args if isinstance(x, TensorProxyInterface)) + devices = tuple(x.device for x in args if isinstance(x, TensorProxyInterface) and x.device.type != "meta") + if len(devices) > 1: device = devices[0] for otherdevice in devices[1:]: From 8f831c3a95ed1f61f76cb8ccc1a239cd1dfb5a7c Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 16 Jul 2024 16:35:47 +0200 Subject: [PATCH 2/4] ad hoc benchmarking --- examples/ggml-quant/thunder_ggmlquant.ipynb | 606 ++++++++++++++++---- thunder/core/prims.py | 4 +- 2 files changed, 508 insertions(+), 102 deletions(-) diff --git a/examples/ggml-quant/thunder_ggmlquant.ipynb b/examples/ggml-quant/thunder_ggmlquant.ipynb index 52964e37ea..3a694639fc 100644 --- a/examples/ggml-quant/thunder_ggmlquant.ipynb +++ b/examples/ggml-quant/thunder_ggmlquant.ipynb @@ -38,20 +38,9 @@ { "cell_type": "code", "execution_count": 2, - "id": "5a7a129a-c646-4adf-8efe-98f217a4d786", + "id": "51d82e38-ddcc-443b-b846-1d607404d7b7", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading model '/home/tv/data/firma/grid/thunder/litgpt/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/lit_model.pth' with {'name': 'Llama-3-8B-Instruct', 'hf_config': {'name': 'Meta-Llama-3-8B-Instruct', 'org': 'meta-llama'}, 'scale_embeddings': False, 'block_size': 8192, 'vocab_size': 128000, 'padding_multiple': 512, 'padded_vocab_size': 128256, 'n_layer': 32, 'n_head': 32, 'head_size': 128, 'n_embd': 4096, 'rotary_percentage': 1.0, 'parallel_residual': False, 'bias': False, 'lm_head_bias': False, 'n_query_groups': 8, 'shared_attention_norm': False, 'norm_class_name': 'RMSNorm', 'norm_eps': 1e-05, 'mlp_class_name': 'LLaMAMLP', 'gelu_approximate': 'none', 'intermediate_size': 14336, 'rope_condense_ratio': 1, 'rope_base': 500000, 'n_expert': 0, 'n_expert_per_token': 0, 'rope_n_elem': 128}\n", - "Time to 
instantiate model: 0.08 seconds.\n", - "You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "Time to load the model weights: 14.32 seconds.\n" - ] - } - ], + "outputs": [], "source": [ "# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.\n", "\n", @@ -184,9 +173,16 @@ " if token == eos_id:\n", " break\n", " input_pos = input_pos.add_(1)\n", - " return torch.cat(tokens)\n", - "\n", - "\n", + " return torch.cat(tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "39943398-afaa-40bb-8de1-39d021914245", + "metadata": {}, + "outputs": [], + "source": [ "with torch.inference_mode():\n", " prompt: str = \"What food do llamas eat?\"\n", " num_samples: int = 1\n", @@ -203,7 +199,8 @@ " \"\"\"Generates text samples based on a pre-trained model and tokenizer.\n", "\n", " Args:\n", - " prompt: The prompt string to use for generating the samples.\n", + " prompt: The prompt strin\n", + " g to use for generating the samples.\n", " num_samples: The number of text samples to generate.\n", " max_new_tokens: The number of generation steps to take.\n", " top_k: The number of top most probable tokens to consider in the sampling process.\n", @@ -252,8 +249,17 @@ " prompt = prompt_style.apply(prompt)\n", " encoded = tokenizer.encode(prompt, device=fabric.device)\n", " prompt_length = encoded.size(0)\n", - " max_returned_tokens = prompt_length + max_new_tokens\n", - "\n", + " max_returned_tokens = prompt_length + max_new_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7afe7f82-d468-4570-83f5-9bfc4817183c", + "metadata": {}, + "outputs": [], + "source": [ + "if 0:\n", " fabric.print(f\"Loading model {str(checkpoint_path)!r} with {config.__dict__}\", file=sys.stderr)\n", " t0 = time.perf_counter()\n", " with fabric.init_module(empty_init=True):\n", @@ -287,66 +293,27 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "483cc7da-4db8-4953-b5c7-657b709cc8ce", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Seed set to 1234\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "system\n", - "\n", - "You are a helpful assistant.\n", - "user\n", - "\n", - "What food do llamas eat?assistant\n", - "\n", - "Llamas are ruminant animals, which means they have a four-chambered stomach and primarily eat plant-based foods. Here are some of the main foods that llamas like to munch on:\n", - "\n", - "1. Grass: Llamas love to graze on grasses, including orchard grass, timothy grass, and Bermuda grass.\n", - "2. Hay: Hay is a staple in a llama's diet. They enjoy a variety of hay types, such as alfalfa, oat hay, and grass hay.\n", - "3. Grains: Llamas can also eat grains like oats, barley, and corn, but in moderation. Too much grain can lead to digestive issues.\n", - "4. Fruits and vegetables: Llamas enjoy treats like apples, carrots, sweet potatoes, and leafy greens like lettuce and spinach.\n", - "5. Minerals: Llamas require access to a mineral block or loose minerals, like calcium and phosphorus, to stay healthy.\n", - "6. Salt: Llamas need access to salt licks or a salt block to regulate their electrolyte levels.\n", - "7. 
Pellets: A high-fiber pellet specifically formulated for llamas can be a convenient and nutritious addition to their diet.\n", - "\n", - "Remember to always provide fresh water and a mineral block or loose minerals along\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Time for inference 1: 6.42 sec total, 39.88 tokens/sec\n", - "Memory used: 18.31 GB\n" - ] - } - ], + "outputs": [], "source": [ - "with torch.inference_mode():\n", - " L.seed_everything(1234)\n", - " for i in range(num_samples):\n", - " t0 = time.perf_counter()\n", - " y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)\n", - " t = time.perf_counter() - t0\n", - " for block in model.transformer.h:\n", - " block.attn.kv_cache.reset_parameters()\n", - " fabric.print(tokenizer.decode(y))\n", - " tokens_generated = y.size(0) - prompt_length\n", - " fabric.print(\n", - " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", - " )\n", - " if fabric.device.type == \"cuda\":\n", - " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" + "if 0:\n", + " with torch.inference_mode():\n", + " L.seed_everything(1234)\n", + " for i in range(num_samples):\n", + " t0 = time.perf_counter()\n", + " y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)\n", + " t = time.perf_counter() - t0\n", + " for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", + " fabric.print(tokenizer.decode(y))\n", + " tokens_generated = y.size(0) - prompt_length\n", + " fabric.print(\n", + " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", + " )\n", + " if fabric.device.type == \"cuda\":\n", + " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" ] }, { @@ -359,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "id": "a76d44e0-dd65-4e6c-a747-96f2ef1e33fb", "metadata": {}, "outputs": [], @@ -529,7 +496,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 7, "id": "c08de032-8db4-49ce-9382-868263d00d4d", "metadata": {}, "outputs": [], @@ -539,7 +506,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "id": "9c070756-4857-4741-9658-ee83b9dceeaf", "metadata": {}, "outputs": [], @@ -549,20 +516,27 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "id": "250b5edd-31d0-456c-8bf1-96fc73e4e5af", "metadata": {}, "outputs": [], "source": [ + "import litgpt\n", "with torch.device(\"meta\"):\n", " m = thunder.tests.litgpt_model.GPT.from_name('Llama-3-8B-Instruct')\n", " m.requires_grad_(False)\n", - "m.cos, m.sin = m.rope_cache(device=torch.device('cuda'))" + " #del m.transformer.h[2:]\n", + "# enable the kv cache\n", + "device = \"cuda\"\n", + "with torch.device(device):\n", + " m.max_seq_length = max_returned_tokens\n", + " m.set_kv_cache(batch_size=1)\n", + "m.cos, m.sin = m.rope_cache(device=torch.device('cuda'))\n" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "id": "39264594-703f-4393-ac2d-6c88e50b5764", "metadata": {}, "outputs": [], @@ -575,42 +549,474 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "id": "a87b6c62-afa7-4537-90ac-de99b2a7c4af", "metadata": {}, + "outputs": [], + 
"source": [ + "#a = torch.randint(1, 100, (1, 64), device=\"cuda\")\n", + "#tm(a)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a7a0c755-77a3-44a4-8f11-0b3d308c75f2", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "model = tm\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a2d8666d-387f-4676-a01b-04c8567a9c5a", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Seed set to 1234\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", + " 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], device='cuda:0')\n", + "tensor([30], device='cuda:0')\n", + "tensor([31], device='cuda:0')\n", + "tensor([32], device='cuda:0')\n", + "tensor([33], device='cuda:0')\n", + "tensor([34], device='cuda:0')\n", + "tensor([35], device='cuda:0')\n", + "tensor([36], device='cuda:0')\n", + "tensor([37], device='cuda:0')\n", + "tensor([38], device='cuda:0')\n", + "tensor([39], device='cuda:0')\n", + "tensor([40], device='cuda:0')\n", + "tensor([41], device='cuda:0')\n", + "tensor([42], device='cuda:0')\n", + "tensor([43], device='cuda:0')\n", + "tensor([44], device='cuda:0')\n", + "tensor([45], device='cuda:0')\n", + "tensor([46], device='cuda:0')\n", + "tensor([47], device='cuda:0')\n", + "tensor([48], device='cuda:0')\n", + "tensor([49], device='cuda:0')\n", + "tensor([50], device='cuda:0')\n", + "tensor([51], device='cuda:0')\n", + "tensor([52], device='cuda:0')\n", + "tensor([53], device='cuda:0')\n", + "tensor([54], device='cuda:0')\n", + "tensor([55], device='cuda:0')\n", + "tensor([56], device='cuda:0')\n", + "tensor([57], device='cuda:0')\n", + "tensor([58], device='cuda:0')\n", + "tensor([59], device='cuda:0')\n", + "tensor([60], device='cuda:0')\n", + "tensor([61], device='cuda:0')\n", + "tensor([62], device='cuda:0')\n", + "tensor([63], device='cuda:0')\n", + "tensor([64], device='cuda:0')\n", + "tensor([65], device='cuda:0')\n", + "tensor([66], device='cuda:0')\n", + "tensor([67], device='cuda:0')\n", + "tensor([68], device='cuda:0')\n", + "tensor([69], device='cuda:0')\n", + "tensor([70], device='cuda:0')\n", + "tensor([71], device='cuda:0')\n", + "tensor([72], device='cuda:0')\n", + "tensor([73], device='cuda:0')\n", + "tensor([74], device='cuda:0')\n", + "tensor([75], device='cuda:0')\n", + "tensor([76], device='cuda:0')\n", + "tensor([77], device='cuda:0')\n", + "tensor([78], device='cuda:0')\n", + "tensor([79], device='cuda:0')\n", + "tensor([80], device='cuda:0')\n", + "tensor([81], device='cuda:0')\n", + "tensor([82], device='cuda:0')\n", + "tensor([83], device='cuda:0')\n", + "tensor([84], device='cuda:0')\n", + "tensor([85], device='cuda:0')\n", + "tensor([86], device='cuda:0')\n", + "tensor([87], device='cuda:0')\n", + "tensor([88], device='cuda:0')\n", + "tensor([89], device='cuda:0')\n", + "tensor([90], device='cuda:0')\n", + "tensor([91], device='cuda:0')\n", + "tensor([92], device='cuda:0')\n", + "tensor([93], device='cuda:0')\n", + "tensor([94], device='cuda:0')\n", + "tensor([95], device='cuda:0')\n", + "tensor([96], device='cuda:0')\n", + "tensor([97], device='cuda:0')\n", + "tensor([98], device='cuda:0')\n", + "tensor([99], device='cuda:0')\n", + "tensor([100], device='cuda:0')\n", + "tensor([101], device='cuda:0')\n", + "tensor([102], device='cuda:0')\n", + "tensor([103], device='cuda:0')\n", + "tensor([104], device='cuda:0')\n", + "tensor([105], 
device='cuda:0')\n", + "tensor([106], device='cuda:0')\n", + "tensor([107], device='cuda:0')\n", + "tensor([108], device='cuda:0')\n", + "tensor([109], device='cuda:0')\n", + "tensor([110], device='cuda:0')\n", + "tensor([111], device='cuda:0')\n", + "tensor([112], device='cuda:0')\n", + "tensor([113], device='cuda:0')\n", + "tensor([114], device='cuda:0')\n", + "tensor([115], device='cuda:0')\n", + "tensor([116], device='cuda:0')\n", + "tensor([117], device='cuda:0')\n", + "tensor([118], device='cuda:0')\n", + "tensor([119], device='cuda:0')\n", + "tensor([120], device='cuda:0')\n", + "tensor([121], device='cuda:0')\n", + "tensor([122], device='cuda:0')\n", + "tensor([123], device='cuda:0')\n", + "tensor([124], device='cuda:0')\n", + "tensor([125], device='cuda:0')\n", + "tensor([126], device='cuda:0')\n", + "tensor([127], device='cuda:0')\n", + "tensor([128], device='cuda:0')\n", + "tensor([129], device='cuda:0')\n", + "tensor([130], device='cuda:0')\n", + "tensor([131], device='cuda:0')\n", + "tensor([132], device='cuda:0')\n", + "tensor([133], device='cuda:0')\n", + "tensor([134], device='cuda:0')\n", + "tensor([135], device='cuda:0')\n", + "tensor([136], device='cuda:0')\n", + "tensor([137], device='cuda:0')\n", + "tensor([138], device='cuda:0')\n", + "tensor([139], device='cuda:0')\n", + "tensor([140], device='cuda:0')\n", + "tensor([141], device='cuda:0')\n", + "tensor([142], device='cuda:0')\n", + "tensor([143], device='cuda:0')\n", + "tensor([144], device='cuda:0')\n", + "tensor([145], device='cuda:0')\n", + "tensor([146], device='cuda:0')\n", + "tensor([147], device='cuda:0')\n", + "tensor([148], device='cuda:0')\n", + "tensor([149], device='cuda:0')\n", + "tensor([150], device='cuda:0')\n", + "tensor([151], device='cuda:0')\n", + "tensor([152], device='cuda:0')\n", + "tensor([153], device='cuda:0')\n", + "tensor([154], device='cuda:0')\n", + "tensor([155], device='cuda:0')\n", + "tensor([156], device='cuda:0')\n", + "tensor([157], device='cuda:0')\n", + "tensor([158], device='cuda:0')\n", + "tensor([159], device='cuda:0')\n", + "tensor([160], device='cuda:0')\n", + "tensor([161], device='cuda:0')\n", + "tensor([162], device='cuda:0')\n", + "tensor([163], device='cuda:0')\n", + "tensor([164], device='cuda:0')\n", + "tensor([165], device='cuda:0')\n", + "tensor([166], device='cuda:0')\n", + "tensor([167], device='cuda:0')\n", + "tensor([168], device='cuda:0')\n", + "tensor([169], device='cuda:0')\n", + "tensor([170], device='cuda:0')\n", + "tensor([171], device='cuda:0')\n", + "tensor([172], device='cuda:0')\n", + "tensor([173], device='cuda:0')\n", + "tensor([174], device='cuda:0')\n", + "tensor([175], device='cuda:0')\n", + "tensor([176], device='cuda:0')\n", + "tensor([177], device='cuda:0')\n", + "tensor([178], device='cuda:0')\n", + "tensor([179], device='cuda:0')\n", + "tensor([180], device='cuda:0')\n", + "tensor([181], device='cuda:0')\n", + "tensor([182], device='cuda:0')\n", + "tensor([183], device='cuda:0')\n", + "tensor([184], device='cuda:0')\n", + "tensor([185], device='cuda:0')\n", + "tensor([186], device='cuda:0')\n", + "tensor([187], device='cuda:0')\n", + "tensor([188], device='cuda:0')\n", + "tensor([189], device='cuda:0')\n", + "tensor([190], device='cuda:0')\n", + "tensor([191], device='cuda:0')\n", + "tensor([192], device='cuda:0')\n", + "tensor([193], device='cuda:0')\n", + "tensor([194], device='cuda:0')\n", + "tensor([195], device='cuda:0')\n", + "tensor([196], device='cuda:0')\n", + "tensor([197], device='cuda:0')\n", + "tensor([198], device='cuda:0')\n", 
+ "tensor([199], device='cuda:0')\n", + "tensor([200], device='cuda:0')\n", + "tensor([201], device='cuda:0')\n", + "tensor([202], device='cuda:0')\n", + "tensor([203], device='cuda:0')\n", + "tensor([204], device='cuda:0')\n", + "tensor([205], device='cuda:0')\n", + "tensor([206], device='cuda:0')\n", + "tensor([207], device='cuda:0')\n", + "tensor([208], device='cuda:0')\n", + "tensor([209], device='cuda:0')\n", + "tensor([210], device='cuda:0')\n", + "tensor([211], device='cuda:0')\n", + "tensor([212], device='cuda:0')\n", + "tensor([213], device='cuda:0')\n", + "tensor([214], device='cuda:0')\n", + "tensor([215], device='cuda:0')\n", + "tensor([216], device='cuda:0')\n", + "tensor([217], device='cuda:0')\n", + "tensor([218], device='cuda:0')\n", + "tensor([219], device='cuda:0')\n", + "tensor([220], device='cuda:0')\n", + "tensor([221], device='cuda:0')\n", + "tensor([222], device='cuda:0')\n", + "tensor([223], device='cuda:0')\n", + "tensor([224], device='cuda:0')\n", + "tensor([225], device='cuda:0')\n", + "tensor([226], device='cuda:0')\n", + "tensor([227], device='cuda:0')\n", + "tensor([228], device='cuda:0')\n", + "tensor([229], device='cuda:0')\n", + "tensor([230], device='cuda:0')\n", + "tensor([231], device='cuda:0')\n", + "tensor([232], device='cuda:0')\n", + "tensor([233], device='cuda:0')\n", + "tensor([234], device='cuda:0')\n", + "tensor([235], device='cuda:0')\n", + "tensor([236], device='cuda:0')\n", + "tensor([237], device='cuda:0')\n", + "tensor([238], device='cuda:0')\n", + "tensor([239], device='cuda:0')\n", + "tensor([240], device='cuda:0')\n", + "tensor([241], device='cuda:0')\n", + "tensor([242], device='cuda:0')\n", + "tensor([243], device='cuda:0')\n", + "tensor([244], device='cuda:0')\n", + "tensor([245], device='cuda:0')\n", + "tensor([246], device='cuda:0')\n", + "tensor([247], device='cuda:0')\n", + "tensor([248], device='cuda:0')\n", + "tensor([249], device='cuda:0')\n", + "tensor([250], device='cuda:0')\n", + "tensor([251], device='cuda:0')\n", + "tensor([252], device='cuda:0')\n", + "tensor([253], device='cuda:0')\n", + "tensor([254], device='cuda:0')\n", + "tensor([255], device='cuda:0')\n", + "tensor([256], device='cuda:0')\n", + "tensor([257], device='cuda:0')\n", + "tensor([258], device='cuda:0')\n", + "tensor([259], device='cuda:0')\n", + "tensor([260], device='cuda:0')\n", + "tensor([261], device='cuda:0')\n", + "tensor([262], device='cuda:0')\n", + "tensor([263], device='cuda:0')\n", + "tensor([264], device='cuda:0')\n", + "tensor([265], device='cuda:0')\n", + "tensor([266], device='cuda:0')\n", + "tensor([267], device='cuda:0')\n", + "tensor([268], device='cuda:0')\n", + "tensor([269], device='cuda:0')\n", + "tensor([270], device='cuda:0')\n", + "tensor([271], device='cuda:0')\n", + "tensor([272], device='cuda:0')\n", + "tensor([273], device='cuda:0')\n", + "tensor([274], device='cuda:0')\n", + "tensor([275], device='cuda:0')\n", + "tensor([276], device='cuda:0')\n", + "tensor([277], device='cuda:0')\n", + "tensor([278], device='cuda:0')\n", + "tensor([279], device='cuda:0')\n", + "tensor([280], device='cuda:0')\n", + "tensor([281], device='cuda:0')\n", + "tensor([282], device='cuda:0')\n", + "tensor([283], device='cuda:0')\n", + "tensor([284], device='cuda:0')\n", + "system\n", + "\n", + "You are a helpful assistant.\n", + "user\n", + "\n", + "What food do llamas eat?assistant\n", + "\n", + "I'mamas are domesticated animals, which means they have a unique four-part stomachered digestive system that eat plants-based material. 
Their main food items they do llamas eat:\n", + "\n", + "1:\n", + "\n", + "1:\n", + "\n", + "1. Their diet: Llamamas love to on grass, so they eat grasses of grasses, grass, alfalfa grass, and other grasses. They can be Hay: In the main staple food for llama's diet. They dry hay, such as alfalfa, and other hay hay that's hay and legumes hay.\n", + "3. Grains: They canas also eat various grains, such as oats, barley, and corn. These should not be a part of grains in moderation as they becoming too much.\n", + ". Shilled: Llegumes: Lamas may enjoy fruits, apric, and other fruits. Some people also can eat leaf stalks, collard, and brome leaves.\n", + "5. Plant leaves: In a variety of trees, and twbs of plants. Lllamas find in the form leaves, browse on branches, and leaves for them.\n", + ".\n", + "\n", + " It can also: Llamas require access to mineral supplements, such as salt, and calcium, and trace, to help the health.\n", + "\n", + "Some important minerals like water:\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Time for inference 1: 64.34 sec total, 3.98 tokens/sec\n", + "Memory used: 9.84 GB\n" + ] + } + ], + "source": [ + "num_samples = 1\n", + "with torch.inference_mode():\n", + " L.seed_everything(1234)\n", + " for i in range(num_samples):\n", + " t0 = time.perf_counter()\n", + " y = generate(model, encoded, max_returned_tokens, temperature=temperature, top_k=top_k, top_p=top_p, eos_id=tokenizer.eos_id)\n", + " t = time.perf_counter() - t0\n", + " for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", + " fabric.print(tokenizer.decode(y))\n", + " tokens_generated = y.size(0) - prompt_length\n", + " fabric.print(\n", + " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", + " )\n", + " if fabric.device.type == \"cuda\":\n", + " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92cdfcae-4fc1-434e-acfc-27f9d5c53942", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "65ad6791-a375-4cbe-b74b-2bf08b202d0e", + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[[ 5.6120e-03, -5.2586e-01, 4.9149e-01, ..., -1.3942e-01,\n", - " -1.3927e-01, -1.3927e-01],\n", - " [ 1.3441e+00, 1.6912e+00, 2.4314e+00, ..., -1.2766e+00,\n", - " -1.2766e+00, -1.2766e+00],\n", - " [ 5.4205e+00, 6.1258e+00, 6.0442e+00, ..., -3.8762e+00,\n", - " -3.8776e+00, -3.8776e+00],\n", - " ...,\n", - " [ 8.1336e+00, 8.4065e+00, 8.6552e+00, ..., -2.5787e+00,\n", - " -2.5798e+00, -2.5798e+00],\n", - " [ 8.2335e+00, 8.9660e+00, 8.8519e+00, ..., -2.9638e+00,\n", - " -2.9643e+00, -2.9643e+00],\n", - " [ 8.0281e+00, 8.1912e+00, 8.4859e+00, ..., -2.8124e+00,\n", - " -2.8130e+00, -2.8130e+00]]], device='cuda:0')" + "torch.Size([1, 32, 286, 128])" ] }, - "execution_count": 7, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "a = torch.randint(1, 100, (1, 64), device=\"cuda\")\n", - "tm(a)" + "model.transformer.h[0].attn.kv_cache.v.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "278af111-7b12-4579-8efb-f8a03582c355", + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'GPT' object has no attribute '_forward_module'", + "output_type": "error", + "traceback": [ + 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[15], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_forward_module\u001b[49m\u001b[38;5;241m.\u001b[39mmax_seq_length\n", + "File \u001b[0;32m/usr/local/lib/python3.11/dist-packages/thunder/core/module.py:172\u001b[0m, in \u001b[0;36mThunderModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_model\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_modules[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_model\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m--> 172\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_model, name)\n", + "File \u001b[0;32m/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1893\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1891\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1892\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1893\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[1;32m 1894\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1895\u001b[0m )\n", + "\u001b[0;31mAttributeError\u001b[0m: 'GPT' object has no attribute '_forward_module'" + ] + } + ], + "source": [ + "model._forward_module.max_seq_length" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b33c1b1c-45b4-42d6-b363-9adcbeb8b017", + "metadata": {}, + "outputs": [], + "source": [ + "model._forward_module.config.n_head" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96c7e96a-8794-4219-bd9f-f4dc0e18fde0", + "metadata": {}, + "outputs": [], + "source": [ + "model._model.mask_cache" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86e8c0c4-48cc-44c0-ad21-77a870f14dcb", + "metadata": {}, + "outputs": [], + "source": [ + "m2 = thunder.tests.litgpt_model.OverridenKVCache((1, 32, 286, 128), (1, 32, 286, 128), device=torch.device(\"cuda\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cd6e24a-3c16-4d48-94e3-c1da897ea04f", + "metadata": {}, + "outputs": [], + "source": [ + "tm2 = thunder.jit(m2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1a4cd5f-a4d2-458d-b51c-7cb679049655", + "metadata": {}, + "outputs": [], + "source": [ + "input_pos = torch.tensor([2], device=\"cuda\")\n", + "k, v = torch.randn(2, 1, 32, 1, 128, device=\"cuda\")\n", + "tm2(input_pos, k, v)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"id": "e9a80452-8852-45ea-bad9-04048336d0ef", + "metadata": {}, + "outputs": [], + "source": [ + "m2" ] }, { "cell_type": "code", "execution_count": null, - "id": "945bae75-03bb-43db-9681-db0cc4b4e80b", + "id": "a90b8686-678b-4982-abc9-6bff19abfc45", "metadata": {}, "outputs": [], "source": [] diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 87fdee534d..6cd42c1ee9 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -3624,7 +3624,7 @@ def linear_meta(a: TensorProxy, w: TensorProxy, bias: None | TensorProxy) -> Ten utils.check(isinstance(w, TensorProxy), lambda: f"w={w} was not a TensorProxy!") # Checks that required arguments are on the same device - utils.check(a.device == w.device, lambda: f"Expected a.device={a.device} and w.device={w.device} to be the same!") + utils.check_same_device(a, w) # Acquires the computation dtype and checks that a and w have the same dtype dtype = a.dtype @@ -3684,7 +3684,7 @@ def matmul_meta(a: TensorProxy, b: TensorProxy, /) -> TensorProxy: if a.ndim < 1 or b.ndim < 1: raise NotImplementedError - utils.check(a.device == b.device, lambda: f"Expected a.device={a.device} and b.device={b.device} to be the same") + utils.check_same_device(a, b) utils.check( dtypes.are_same_dtypes(a, b), lambda: f"Expected a.dtype={a.dtype} and b.dtype={b.dtype} to be the same" From 0032aa4d9e4723080bed6d684aca84d0b4ee3f21 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Tue, 16 Jul 2024 20:50:18 +0200 Subject: [PATCH 3/4] wip --- examples/ggml-quant/thunder_ggmlquant.ipynb | 362 +++++--------------- thunder/tests/litgpt_model.py | 4 +- 2 files changed, 85 insertions(+), 281 deletions(-) diff --git a/examples/ggml-quant/thunder_ggmlquant.ipynb b/examples/ggml-quant/thunder_ggmlquant.ipynb index 3a694639fc..53852b01ab 100644 --- a/examples/ggml-quant/thunder_ggmlquant.ipynb +++ b/examples/ggml-quant/thunder_ggmlquant.ipynb @@ -572,11 +572,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "id": "a2d8666d-387f-4676-a01b-04c8567a9c5a", - "metadata": { - "scrolled": true - }, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -589,263 +587,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n", - " 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], device='cuda:0')\n", - "tensor([30], device='cuda:0')\n", - "tensor([31], device='cuda:0')\n", - "tensor([32], device='cuda:0')\n", - "tensor([33], device='cuda:0')\n", - "tensor([34], device='cuda:0')\n", - "tensor([35], device='cuda:0')\n", - "tensor([36], device='cuda:0')\n", - "tensor([37], device='cuda:0')\n", - "tensor([38], device='cuda:0')\n", - "tensor([39], device='cuda:0')\n", - "tensor([40], device='cuda:0')\n", - "tensor([41], device='cuda:0')\n", - "tensor([42], device='cuda:0')\n", - "tensor([43], device='cuda:0')\n", - "tensor([44], device='cuda:0')\n", - "tensor([45], device='cuda:0')\n", - "tensor([46], device='cuda:0')\n", - "tensor([47], device='cuda:0')\n", - "tensor([48], device='cuda:0')\n", - "tensor([49], device='cuda:0')\n", - "tensor([50], device='cuda:0')\n", - "tensor([51], device='cuda:0')\n", - "tensor([52], device='cuda:0')\n", - "tensor([53], device='cuda:0')\n", - "tensor([54], device='cuda:0')\n", - "tensor([55], device='cuda:0')\n", - "tensor([56], device='cuda:0')\n", - "tensor([57], device='cuda:0')\n", - "tensor([58], device='cuda:0')\n", - "tensor([59], device='cuda:0')\n", - "tensor([60], device='cuda:0')\n", - "tensor([61], device='cuda:0')\n", - 
"tensor([62], device='cuda:0')\n", - "tensor([63], device='cuda:0')\n", - "tensor([64], device='cuda:0')\n", - "tensor([65], device='cuda:0')\n", - "tensor([66], device='cuda:0')\n", - "tensor([67], device='cuda:0')\n", - "tensor([68], device='cuda:0')\n", - "tensor([69], device='cuda:0')\n", - "tensor([70], device='cuda:0')\n", - "tensor([71], device='cuda:0')\n", - "tensor([72], device='cuda:0')\n", - "tensor([73], device='cuda:0')\n", - "tensor([74], device='cuda:0')\n", - "tensor([75], device='cuda:0')\n", - "tensor([76], device='cuda:0')\n", - "tensor([77], device='cuda:0')\n", - "tensor([78], device='cuda:0')\n", - "tensor([79], device='cuda:0')\n", - "tensor([80], device='cuda:0')\n", - "tensor([81], device='cuda:0')\n", - "tensor([82], device='cuda:0')\n", - "tensor([83], device='cuda:0')\n", - "tensor([84], device='cuda:0')\n", - "tensor([85], device='cuda:0')\n", - "tensor([86], device='cuda:0')\n", - "tensor([87], device='cuda:0')\n", - "tensor([88], device='cuda:0')\n", - "tensor([89], device='cuda:0')\n", - "tensor([90], device='cuda:0')\n", - "tensor([91], device='cuda:0')\n", - "tensor([92], device='cuda:0')\n", - "tensor([93], device='cuda:0')\n", - "tensor([94], device='cuda:0')\n", - "tensor([95], device='cuda:0')\n", - "tensor([96], device='cuda:0')\n", - "tensor([97], device='cuda:0')\n", - "tensor([98], device='cuda:0')\n", - "tensor([99], device='cuda:0')\n", - "tensor([100], device='cuda:0')\n", - "tensor([101], device='cuda:0')\n", - "tensor([102], device='cuda:0')\n", - "tensor([103], device='cuda:0')\n", - "tensor([104], device='cuda:0')\n", - "tensor([105], device='cuda:0')\n", - "tensor([106], device='cuda:0')\n", - "tensor([107], device='cuda:0')\n", - "tensor([108], device='cuda:0')\n", - "tensor([109], device='cuda:0')\n", - "tensor([110], device='cuda:0')\n", - "tensor([111], device='cuda:0')\n", - "tensor([112], device='cuda:0')\n", - "tensor([113], device='cuda:0')\n", - "tensor([114], device='cuda:0')\n", - "tensor([115], device='cuda:0')\n", - "tensor([116], device='cuda:0')\n", - "tensor([117], device='cuda:0')\n", - "tensor([118], device='cuda:0')\n", - "tensor([119], device='cuda:0')\n", - "tensor([120], device='cuda:0')\n", - "tensor([121], device='cuda:0')\n", - "tensor([122], device='cuda:0')\n", - "tensor([123], device='cuda:0')\n", - "tensor([124], device='cuda:0')\n", - "tensor([125], device='cuda:0')\n", - "tensor([126], device='cuda:0')\n", - "tensor([127], device='cuda:0')\n", - "tensor([128], device='cuda:0')\n", - "tensor([129], device='cuda:0')\n", - "tensor([130], device='cuda:0')\n", - "tensor([131], device='cuda:0')\n", - "tensor([132], device='cuda:0')\n", - "tensor([133], device='cuda:0')\n", - "tensor([134], device='cuda:0')\n", - "tensor([135], device='cuda:0')\n", - "tensor([136], device='cuda:0')\n", - "tensor([137], device='cuda:0')\n", - "tensor([138], device='cuda:0')\n", - "tensor([139], device='cuda:0')\n", - "tensor([140], device='cuda:0')\n", - "tensor([141], device='cuda:0')\n", - "tensor([142], device='cuda:0')\n", - "tensor([143], device='cuda:0')\n", - "tensor([144], device='cuda:0')\n", - "tensor([145], device='cuda:0')\n", - "tensor([146], device='cuda:0')\n", - "tensor([147], device='cuda:0')\n", - "tensor([148], device='cuda:0')\n", - "tensor([149], device='cuda:0')\n", - "tensor([150], device='cuda:0')\n", - "tensor([151], device='cuda:0')\n", - "tensor([152], device='cuda:0')\n", - "tensor([153], device='cuda:0')\n", - "tensor([154], device='cuda:0')\n", - "tensor([155], device='cuda:0')\n", - "tensor([156], 
device='cuda:0')\n", - "tensor([157], device='cuda:0')\n", - "tensor([158], device='cuda:0')\n", - "tensor([159], device='cuda:0')\n", - "tensor([160], device='cuda:0')\n", - "tensor([161], device='cuda:0')\n", - "tensor([162], device='cuda:0')\n", - "tensor([163], device='cuda:0')\n", - "tensor([164], device='cuda:0')\n", - "tensor([165], device='cuda:0')\n", - "tensor([166], device='cuda:0')\n", - "tensor([167], device='cuda:0')\n", - "tensor([168], device='cuda:0')\n", - "tensor([169], device='cuda:0')\n", - "tensor([170], device='cuda:0')\n", - "tensor([171], device='cuda:0')\n", - "tensor([172], device='cuda:0')\n", - "tensor([173], device='cuda:0')\n", - "tensor([174], device='cuda:0')\n", - "tensor([175], device='cuda:0')\n", - "tensor([176], device='cuda:0')\n", - "tensor([177], device='cuda:0')\n", - "tensor([178], device='cuda:0')\n", - "tensor([179], device='cuda:0')\n", - "tensor([180], device='cuda:0')\n", - "tensor([181], device='cuda:0')\n", - "tensor([182], device='cuda:0')\n", - "tensor([183], device='cuda:0')\n", - "tensor([184], device='cuda:0')\n", - "tensor([185], device='cuda:0')\n", - "tensor([186], device='cuda:0')\n", - "tensor([187], device='cuda:0')\n", - "tensor([188], device='cuda:0')\n", - "tensor([189], device='cuda:0')\n", - "tensor([190], device='cuda:0')\n", - "tensor([191], device='cuda:0')\n", - "tensor([192], device='cuda:0')\n", - "tensor([193], device='cuda:0')\n", - "tensor([194], device='cuda:0')\n", - "tensor([195], device='cuda:0')\n", - "tensor([196], device='cuda:0')\n", - "tensor([197], device='cuda:0')\n", - "tensor([198], device='cuda:0')\n", - "tensor([199], device='cuda:0')\n", - "tensor([200], device='cuda:0')\n", - "tensor([201], device='cuda:0')\n", - "tensor([202], device='cuda:0')\n", - "tensor([203], device='cuda:0')\n", - "tensor([204], device='cuda:0')\n", - "tensor([205], device='cuda:0')\n", - "tensor([206], device='cuda:0')\n", - "tensor([207], device='cuda:0')\n", - "tensor([208], device='cuda:0')\n", - "tensor([209], device='cuda:0')\n", - "tensor([210], device='cuda:0')\n", - "tensor([211], device='cuda:0')\n", - "tensor([212], device='cuda:0')\n", - "tensor([213], device='cuda:0')\n", - "tensor([214], device='cuda:0')\n", - "tensor([215], device='cuda:0')\n", - "tensor([216], device='cuda:0')\n", - "tensor([217], device='cuda:0')\n", - "tensor([218], device='cuda:0')\n", - "tensor([219], device='cuda:0')\n", - "tensor([220], device='cuda:0')\n", - "tensor([221], device='cuda:0')\n", - "tensor([222], device='cuda:0')\n", - "tensor([223], device='cuda:0')\n", - "tensor([224], device='cuda:0')\n", - "tensor([225], device='cuda:0')\n", - "tensor([226], device='cuda:0')\n", - "tensor([227], device='cuda:0')\n", - "tensor([228], device='cuda:0')\n", - "tensor([229], device='cuda:0')\n", - "tensor([230], device='cuda:0')\n", - "tensor([231], device='cuda:0')\n", - "tensor([232], device='cuda:0')\n", - "tensor([233], device='cuda:0')\n", - "tensor([234], device='cuda:0')\n", - "tensor([235], device='cuda:0')\n", - "tensor([236], device='cuda:0')\n", - "tensor([237], device='cuda:0')\n", - "tensor([238], device='cuda:0')\n", - "tensor([239], device='cuda:0')\n", - "tensor([240], device='cuda:0')\n", - "tensor([241], device='cuda:0')\n", - "tensor([242], device='cuda:0')\n", - "tensor([243], device='cuda:0')\n", - "tensor([244], device='cuda:0')\n", - "tensor([245], device='cuda:0')\n", - "tensor([246], device='cuda:0')\n", - "tensor([247], device='cuda:0')\n", - "tensor([248], device='cuda:0')\n", - "tensor([249], device='cuda:0')\n", 
- "tensor([250], device='cuda:0')\n", - "tensor([251], device='cuda:0')\n", - "tensor([252], device='cuda:0')\n", - "tensor([253], device='cuda:0')\n", - "tensor([254], device='cuda:0')\n", - "tensor([255], device='cuda:0')\n", - "tensor([256], device='cuda:0')\n", - "tensor([257], device='cuda:0')\n", - "tensor([258], device='cuda:0')\n", - "tensor([259], device='cuda:0')\n", - "tensor([260], device='cuda:0')\n", - "tensor([261], device='cuda:0')\n", - "tensor([262], device='cuda:0')\n", - "tensor([263], device='cuda:0')\n", - "tensor([264], device='cuda:0')\n", - "tensor([265], device='cuda:0')\n", - "tensor([266], device='cuda:0')\n", - "tensor([267], device='cuda:0')\n", - "tensor([268], device='cuda:0')\n", - "tensor([269], device='cuda:0')\n", - "tensor([270], device='cuda:0')\n", - "tensor([271], device='cuda:0')\n", - "tensor([272], device='cuda:0')\n", - "tensor([273], device='cuda:0')\n", - "tensor([274], device='cuda:0')\n", - "tensor([275], device='cuda:0')\n", - "tensor([276], device='cuda:0')\n", - "tensor([277], device='cuda:0')\n", - "tensor([278], device='cuda:0')\n", - "tensor([279], device='cuda:0')\n", - "tensor([280], device='cuda:0')\n", - "tensor([281], device='cuda:0')\n", - "tensor([282], device='cuda:0')\n", - "tensor([283], device='cuda:0')\n", - "tensor([284], device='cuda:0')\n", "system\n", "\n", "You are a helpful assistant.\n", @@ -853,33 +594,29 @@ "\n", "What food do llamas eat?assistant\n", "\n", - "I'mamas are domesticated animals, which means they have a unique four-part stomachered digestive system that eat plants-based material. Their main food items they do llamas eat:\n", - "\n", - "1:\n", - "\n", - "1:\n", - "\n", - "1. Their diet: Llamamas love to on grass, so they eat grasses of grasses, grass, alfalfa grass, and other grasses. They can be Hay: In the main staple food for llama's diet. They dry hay, such as alfalfa, and other hay hay that's hay and legumes hay.\n", - "3. Grains: They canas also eat various grains, such as oats, barley, and corn. These should not be a part of grains in moderation as they becoming too much.\n", - ". Shilled: Llegumes: Lamas may enjoy fruits, apric, and other fruits. Some people also can eat leaf stalks, collard, and brome leaves.\n", - "5. Plant leaves: In a variety of trees, and twbs of plants. Lllamas find in the form leaves, browse on branches, and leaves for them.\n", - ".\n", - "\n", - " It can also: Llamas require access to mineral supplements, such as salt, and calcium, and trace, to help the health.\n", - "\n", - "Some important minerals like water:\n" + "Lema of. 15.25, but the best the righth, such as the as a unique solutions for a; of the dog the dofile://def (What's sake. 2011 The Ultimate ( ( (\n", + "def (Supposeidon 4\n", + "Question 3 Please visitations on the kitchen sink =. 3rdquo;ts (\n", + "package apply for the rolex2019: The number Converse = perform.\n", + "def ( collected, the quicksandwich In the United we shall be a\n", + "Title of thelaborne, The game The Supers;ing ofsted with the what, the American Football is a complete storylines,\"data-driven from the \" target=\"_E, makestudies of the first-timeous. 
Findhorn in Science the United States the World Wide of the first appeared this is a solutioningenuous for the entire website, andthe - (Supers andthe most widely recognized as Philosophy onalleg work\n", + "Question the importance=\"The the first appearedition of the idea from the world is the importance of these the story andourishing andtheatre andtheoretical\n", + "def (SOME comment=20\n", + "Titley, andthe\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "Time for inference 1: 64.34 sec total, 3.98 tokens/sec\n", - "Memory used: 9.84 GB\n" + "Time for inference 1: 211.12 sec total, 1.21 tokens/sec\n", + "Memory used: 9.24 GB\n" ] } ], "source": [ + "for block in model.transformer.h:\n", + " block.attn.kv_cache.reset_parameters()\n", "num_samples = 1\n", "with torch.inference_mode():\n", " L.seed_everything(1234)\n", @@ -892,7 +629,7 @@ " fabric.print(tokenizer.decode(y))\n", " tokens_generated = y.size(0) - prompt_length\n", " fabric.print(\n", - " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", + " f\"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec\", file=sys.stderr\n", " )\n", " if fabric.device.type == \"cuda\":\n", " fabric.print(f\"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB\", file=sys.stderr)" @@ -1019,6 +756,73 @@ "id": "a90b8686-678b-4982-abc9-6bff19abfc45", "metadata": {}, "outputs": [], + "source": [ + "torch.Tensor.index_add" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "97917740-cd76-4471-b1fd-284c3b159c65", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m\n", + " \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransformer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mh\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mattn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkv_cache\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0minput_pos\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m\n", + "Define the computation performed at every call.\n", + "\n", + "Should be overridden by all subclasses.\n", + "\n", + ".. 
note::\n", + " Although the recipe for forward pass needs to be defined within\n", + " this function, one should call the :class:`Module` instance afterwards\n", + " instead of this since the former takes care of running the\n", + " registered hooks while the latter silently ignores them.\n", + "\u001b[0;31mSource:\u001b[0m \n", + " \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_pos\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# move the buffer to the activation dtype for when AMP is used\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# update the cache\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# NOTE: `torch._dynamo.is_compiling` is being deprecated, we should update this once all versions have `torch.compiler.is_compiling`.\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mis_compiling\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_compiling\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompiler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"is_compiling\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_dynamo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_compiling\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_compiling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# inductor doesn't support `index_add` with bfloat16\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_copy_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_copy_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# See issue: \"Support more indexing operators (index_copy and index_add)\"\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_copy_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex_copy_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_pos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;31m# THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m /usr/local/lib/python3.11/dist-packages/thunder/tests/litgpt_model.py\n", + "\u001b[0;31mType:\u001b[0m method" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "?? 
m.transformer.h[0].attn.kv_cache.forward" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3daf3945-91e3-4836-806f-ef8dd972515f", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/thunder/tests/litgpt_model.py b/thunder/tests/litgpt_model.py index 65aed8c386..74817cf05e 100644 --- a/thunder/tests/litgpt_model.py +++ b/thunder/tests/litgpt_model.py @@ -145,8 +145,8 @@ def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> v = self.v.index_copy_(2, input_pos, v) return k, v # See issue: "Support more indexing operators (index_copy and index_add)" - k = self.k = torch.index_add(self.k, 2, input_pos, k) - v = self.v = torch.index_add(self.v, 2, input_pos, v) + k = self.k = self.k.index_copy(2, input_pos, k) + v = self.v = self.v.index_copy(2, input_pos, v) # THUNDER bug: cannot return self.k, self.v here (may be cuda graphs related - no minimum repro) return k, v From 5275cbffeb9098c982326bb3f643de945970b8bc Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 19 Jul 2024 11:43:13 +0000 Subject: [PATCH 4/4] Fix typo --- examples/ggml-quant/thunder_ggmlquant.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/ggml-quant/thunder_ggmlquant.ipynb b/examples/ggml-quant/thunder_ggmlquant.ipynb index 42eaeb4d7c..103902fbce 100644 --- a/examples/ggml-quant/thunder_ggmlquant.ipynb +++ b/examples/ggml-quant/thunder_ggmlquant.ipynb @@ -199,8 +199,7 @@ " \"\"\"Generates text samples based on a pre-trained model and tokenizer.\n", "\n", " Args:\n", - " prompt: The prompt strin\n", - " g to use for generating the samples.\n", + " prompt: The prompt string to use for generating the samples.\n", " num_samples: The number of text samples to generate.\n", " max_new_tokens: The number of generation steps to take.\n", " top_k: The number of top most probable tokens to consider in the sampling process.\n",
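The thunder/tests/litgpt_model.py hunk above switches the KV-cache update from torch.index_add to index_copy, and the notebook now resets each block's kv_cache before generating. A minimal sketch in plain PyTorch, with illustrative shapes rather than the litgpt KVCache API, shows the difference the fix relies on: index_add accumulates into whatever the cache slot already holds, so stale or repeated writes corrupt it, while index_copy overwrites the slot with the latest key/value.

    import torch

    # Illustrative shapes only: (batch, heads, max_seq, head_dim)
    cache = torch.zeros(1, 1, 4, 2)
    new_kv = torch.ones(1, 1, 1, 2)   # one new position worth of K or V
    pos = torch.tensor([1])

    # index_add accumulates: writing the same position twice doubles the stored value
    added = torch.index_add(cache, 2, pos, new_kv)
    added = torch.index_add(added, 2, pos, new_kv)

    # index_copy overwrites: the slot always holds exactly the latest write
    copied = cache.index_copy(2, pos, new_kv)
    copied = copied.index_copy(2, pos, new_kv)

    print(added[0, 0, 1])   # tensor([2., 2.])
    print(copied[0, 0, 1])  # tensor([1., 1.])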
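The thunder/core/prims.py hunks earlier in this series replace the inline device comparisons in linear_meta and matmul_meta with a single utils.check_same_device(a, w) call. The helper's body is not part of this patch, so the following is only a sketch of the semantics those call sites assume, not Thunder's actual implementation.

    # Hypothetical sketch; thunder.core.utils.check_same_device is the real helper,
    # and its implementation may differ from this.
    def check_same_device(*args):
        devices = [a.device for a in args if hasattr(a, "device")]
        if devices and any(d != devices[0] for d in devices[1:]):
            raise RuntimeError(f"Expected all arguments to be on the same device, found {devices}")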