From 7c1d38b8165c7bd786eaff427ed73931349b0d18 Mon Sep 17 00:00:00 2001 From: Ofir Zafrir Date: Thu, 4 Apr 2024 23:01:48 +0300 Subject: [PATCH] Add speculative decoding example to OpenVINO quantized generation notebook (#635) * Add assisted generation example to notebook * Bug fixes * Add acceptance rate measurement * Fix explanation of AR formula * Fix norm AR calc to ignore last window * Minor fixes --- .../openvino/quantized_generation_demo.ipynb | 406 +++++++++++++++++- 1 file changed, 389 insertions(+), 17 deletions(-) diff --git a/notebooks/openvino/quantized_generation_demo.ipynb b/notebooks/openvino/quantized_generation_demo.ipynb index 582b463346..7671064088 100644 --- a/notebooks/openvino/quantized_generation_demo.ipynb +++ b/notebooks/openvino/quantized_generation_demo.ipynb @@ -223,14 +223,346 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "5179c53d-0436-4ee9-9367-2625a8d3e262", + "metadata": {}, + "source": [ + "## Assisted generation\n", + "Auto-regressive language models generate outputs token by token. Assisted generation (AG) is a general name for a group of methods that speculate the next generated tokens and then use the language model to validate the speculated tokens and accept/reject them.\n", + "AG is a great method to accelerate LMs running locally on your computer as it reduces memory bandwidth requirements and can speedup generation by 1.5x-3x without any accuracy degradation.\n", + "You can read more on assisted generation here in this great [blog post](https://huggingface.co/blog/assisted-generation).\n", + "\n", + "\n", + "In this section we will present how to run Phi-2 with two AG methods that are well supported within 🤗 transformers: Prompt Lookahead Decoding (PLD) and Speculative Decoding.\n", + "\n", + "To use Phi-2 with AG we will need to export the model again with `stateful=False` as OpenVINO stateful models don't support speculative decoding yet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adc4484f-8234-4206-9f28-7a02a7444e25", + "metadata": {}, + "outputs": [], + "source": [ + "# Save the model in a different directory to set it apart from the stateful model\n", + "save_name = model_name.split(\"/\")[-1] + \"_openvino_stateless\"\n", + "\n", + "load_kwargs[\"ov_config\"][\"CACHE_DIR\"] = os.path.join(save_name, \"model_cache\")\n", + "\n", + "# Check whether the model was already exported\n", + "saved = os.path.exists(save_name)\n", + "\n", + "# We can use the same loading attributes, the only differece is the stateful attribute\n", + "stateless_model = OVModelForCausalLM.from_pretrained(\n", + " model_name if not saved else save_name,\n", + " export=not saved,\n", + " stateful=False,\n", + " **load_kwargs,\n", + ")\n", + "\n", + "# Save the exported model locally\n", + "if not saved:\n", + " stateless_model.save_pretrained(save_name)\n", + " tokenizer.save_pretrained(save_name)\n", + "\n", + "stateless_model.compile()" + ] + }, + { + "cell_type": "markdown", + "id": "98d34b03-55e0-4606-be26-5722d6868679", + "metadata": {}, + "source": [ + "### Prompt lookahead decoding\n", + "Now we will run the same example from before with PLD enabled. \n", + "PLD speculates tokens by searching the last n-gram (usually 2-gram) in the sequence inside the prompt, if we find a match, we will take the next few tokens (configured with `prompt_lookup_num_tokens`) as our speculation, if a match is not found the code will revert back to auto-regressive generation.\n", + "\n", + "We will run the same example from before with PLD. 
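Before running it, here is a rough, self-contained sketch of the lookup step in plain Python, just to make the mechanism concrete. This is an illustration only, not the actual 🤗 transformers implementation; the helper name and the sample token ids are made up:

```python
def lookup_candidates(token_ids, ngram_size=2, num_candidates=3):
    # Illustrative sketch of prompt lookup: find the most recent earlier occurrence of the
    # last `ngram_size` tokens and propose the tokens that followed it as the speculation.
    ngram = token_ids[-ngram_size:]
    # Search backwards, excluding the trailing n-gram itself
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == ngram:
            follow = token_ids[start + ngram_size:start + ngram_size + num_candidates]
            if follow:
                return follow  # speculated continuation
    return []  # no match: fall back to plain auto-regressive decoding


print(lookup_candidates([5, 8, 2, 7, 9, 5, 8]))  # -> [2, 7, 9]
```

When no earlier occurrence of the n-gram exists, nothing is speculated and decoding proceeds one token at a time as usual.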
To enable PLD, we simply pass the `prompt_lookup_num_tokens` key-word argument to the `generate` function.\n", + "Note that PLD can be great when doing code completion as some sequences of tokens tend to repeat themselves in the same order, names of variables, like `for i in range(...):`, etc.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63a2c7f3-3417-4dec-981d-e99387cc18a8", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TextStreamer\n", + "\n", + "\n", + "# Tokenize the sample\n", + "inputs = tokenizer([sample], return_tensors='pt') \n", + "\n", + "out = stateless_model.generate(\n", + " **inputs,\n", + " max_new_tokens=128,\n", + " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", + " pad_token_id=tokenizer.eos_token_id,\n", + " prompt_lookup_num_tokens=3,\n", + ") " + ] + }, + { + "cell_type": "markdown", + "id": "f0e4e211-e721-48bf-a73f-c987fd3321d3", + "metadata": {}, + "source": [ + "### Speculative decoding\n", + "Speculative Decoding was introduced in the paper [Fast Inference from Transformers via Speculative Decoding](https://arxiv.org/abs/2211.17192).\n", + "In this method the next tokens in the sequence are speculated using another smaller and much faster model which is called a draft model.\n", + "The only constraint we have on the draft model is that it has to have the same vocabulary as the target model, in our case Phi-2.\n", + "Phi-2 and CodeGen models share the same vocabulary and therefore we can use a much smaller and faster CodeGen model as a draft model to Phi-2.\n", + "A common metric for assessing if a draft model is performing well is the acceptance rate.\n", + "The acceptance rate measures how many tokens out of the speculated tokens in each window are accepted by the target model.\n", + "A higher acceptance rate will ensure a higher speedup and therefore it is a very important metric to measure when choosing a draft model.\n", + "\n", + "In this example we will use [CodeGen-350M-Multi](https://huggingface.co/Salesforce/codegen-350M-multi) as a draft model, it has 350M parameters which is ~10x smaller than Phi-2.\n", + "Next, we will prepare our chosen draft model." 
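If you want to double-check the shared-vocabulary constraint for a draft/target pair yourself, a quick tokenizer comparison is enough. This is an optional sanity check rather than part of the demo flow, and it assumes the target checkpoint is `microsoft/phi-2` as used earlier in this notebook:

```python
from transformers import AutoTokenizer

# Sanity check: the draft model must tokenize text exactly like the target model.
target_tok = AutoTokenizer.from_pretrained("microsoft/phi-2")
draft_tok = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-multi")

sample_text = "def fibonacci(n):\n    return n if n < 2 else fibonacci(n - 1) + fibonacci(n - 2)"
print("Same vocab size:", target_tok.vocab_size == draft_tok.vocab_size)
print("Same ids for sample:",
      target_tok(sample_text)["input_ids"] == draft_tok(sample_text)["input_ids"])
```

Since Phi-2 reuses the CodeGen tokenizer, both checks are expected to print `True`.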
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c996ba6-ef66-42a2-9bb4-2320372e4167", + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"Salesforce/codegen-350M-multi\"\n", + "save_name = model_name.split(\"/\")[-1] + \"_openvino_stateless\"\n", + "precision = \"f32\"\n", + "quantization_config = OVWeightQuantizationConfig(\n", + " bits=4,\n", + " sym=False,\n", + " group_size=128,\n", + " ratio=0.8,\n", + ")\n", + "device = \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb366707-4b99-4c79-a235-d3c887136965", + "metadata": {}, + "outputs": [], + "source": [ + "# Load kwargs\n", + "load_kwargs = {\n", + " \"device\": device,\n", + " \"ov_config\": {\n", + " \"PERFORMANCE_HINT\": \"LATENCY\",\n", + " \"INFERENCE_PRECISION_HINT\": precision,\n", + " \"CACHE_DIR\": os.path.join(save_name, \"model_cache\"), # OpenVINO will use this directory as cache\n", + " },\n", + " \"compile\": False,\n", + " \"quantization_config\": quantization_config\n", + "}\n", + "\n", + "# Check whether the model was already exported\n", + "saved = os.path.exists(save_name)\n", + "\n", + "asst_model = OVModelForCausalLM.from_pretrained(\n", + " model_name if not saved else save_name,\n", + " export=not saved,\n", + " stateful=False,\n", + " **load_kwargs,\n", + ")\n", + "\n", + "# Save the exported model locally\n", + "if not saved:\n", + " asst_model.save_pretrained(save_name)\n", + " tokenizer.save_pretrained(save_name)\n", + "\n", + "asst_model.compile()" + ] + }, + { + "cell_type": "markdown", + "id": "4a95efed-22ce-43a0-af2a-e27500cfa514", + "metadata": {}, + "source": [ + "We will set the configuration of the draft model to predict 3 tokens at each forward step, we found that this setting works quite well in the current setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1466938c-0945-4eb6-a80f-dd165cc5eca1", + "metadata": {}, + "outputs": [], + "source": [ + "asst_model.generation_config.num_assistant_tokens = 3\n", + "asst_model.generation_config.num_assistant_tokens_schedule = \"const\"" + ] + }, + { + "cell_type": "markdown", + "id": "74f6b4c4-4d8a-47fd-8172-6502cc5eef29", + "metadata": {}, + "source": [ + "Next, we will run the same example from before with speculative decoding." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a7e1516-6521-4346-bf85-5890341336f0", + "metadata": {}, + "outputs": [], + "source": [ + "out = stateless_model.generate(\n", + " **inputs,\n", + " max_new_tokens=128,\n", + " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", + " pad_token_id=tokenizer.eos_token_id,\n", + " assistant_model=asst_model,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "dab6669b-f3f1-411e-b4b8-31ead823247f", + "metadata": {}, + "source": [ + "Note that in both cases of AG we presented, the generation result is exactly the same as Phi-2 would have generated it without AG!\n", + "\n", + "Like we mentioned before, the acceptance rate (AR) is a very important metric for choosing a draft.\n", + "We would like to make sure that CodeGen has a good AR with Phi-2.\n", + "For that purpose we implemented an easy utility class that uses the inputs' lengths and window sizes to calculate how many tokens were accepted by the target model at each step and calculate the AR using that information." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "480d3e69-0899-4fa8-a85a-cd5a2ce23434", + "metadata": {}, + "outputs": [], + "source": [ + "from functools import wraps\n", + "import numpy as np\n", + "\n", + "\n", + "class AcceptanceRateRecorder:\n", + " def __init__(self, model):\n", + " self.model = model\n", + " self.model_forward = None\n", + " self.model_generate = None\n", + " self.seq_lens = []\n", + " self.win_sizes = []\n", + "\n", + " def __enter__(self):\n", + " # wrap forward method\n", + " if len(self.seq_lens) > 0 or len(self.win_sizes) > 0:\n", + " raise RuntimeError(\"Always use a new instance, don't reuse!\")\n", + " self.model_forward = self.model.forward\n", + " \n", + " @wraps(self.model_forward)\n", + " def forward_wrapper(**kwargs):\n", + " self.seq_lens[-1].append(kwargs.get(\"attention_mask\").shape[-1])\n", + " self.win_sizes[-1].append(kwargs.get(\"input_ids\").shape[-1] - 1)\n", + " return self.model_forward(**kwargs)\n", + " \n", + " self.model.forward = forward_wrapper\n", + " \n", + " # wrap generate method\n", + " self.model_generate = self.model.generate\n", + "\n", + " @wraps(self.model_generate)\n", + " def generate_wrapper(*args, **kwargs):\n", + " self.seq_lens.append([])\n", + " self.win_sizes.append([])\n", + " input_ids = args[0] if len(args) > 0 else kwargs.get(\"input_ids\")\n", + " self.seq_lens[-1].append(input_ids.shape[-1])\n", + " out = self.model_generate(*args, **kwargs)\n", + " self.seq_lens[-1].append(out.shape[-1])\n", + " return out\n", + " self.model.generate = generate_wrapper\n", + " return self\n", + "\n", + " def __exit__(self, type, value, traceback):\n", + " self.model.forward = self.model_forward\n", + " self.model.generate = self.model_generate\n", + " self.model_forward = None\n", + " self.model_generate = None\n", + " # Fix first window size\n", + " for ws, sl in zip(self.win_sizes, self.seq_lens):\n", + " ws[0] -= sl[0] - 1\n", + " # Delete first seq_len, not needed anymore\n", + " self.seq_lens = [sl[1:] for sl in self.seq_lens]\n", + " # Add window size for output to ease calculation later\n", + " for ws, sl in zip(self.win_sizes, self.seq_lens):\n", + " ws.append(0) \n", + "\n", + " def acceptance_rate(self, return_mean=True, normalize=False):\n", + " # ar_per_win = ((cur_seq_len - cur_win_size) - (prev_seq_len - prev_win_size) - 1) / prev_win_size\n", + " ar_per_win = []\n", + " for sl, ws in zip(self.seq_lens, self.win_sizes):\n", + " sl = np.array(sl, dtype=np.float64)\n", + " ws = np.array(ws, dtype=np.float64)\n", + " out_lens = sl - ws\n", + " accepted = (out_lens[1:] - out_lens[:-1] - 1)\n", + " ar_per_win.append(np.divide(accepted, ws[:-1],\n", + " out=np.zeros_like(accepted),where=ws[:-1] != 0))\n", + " ar_per_win = np.hstack(ar_per_win)\n", + " # Normalized AR doesn't take into account windows with size 0\n", + " if normalize:\n", + " ar_per_win = ar_per_win[np.nonzero(np.hstack([ws[:-1] for ws in self.win_sizes]))]\n", + " return np.mean(ar_per_win) if return_mean else ar_per_win" + ] + }, + { + "cell_type": "markdown", + "id": "c35f5e0c-5ed6-4011-a295-80a81fea8b8e", + "metadata": {}, + "source": [ + "Now we can use any dataset for text generation task and measure the AR on that dataset.\n", + "Here we use the [HumanEval](https://huggingface.co/datasets/openai_humaneval) dataset for evaluating code generation.\n", + "We run the model with speculative decoding on 30 samples.\n", + "As you will see, we are getting a very good AR of ~75% for the current configuration.\n", + "\n", + "Note that 
running this test can take a few minutes depending on the number of samples you are evaluating" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "681a4974-43df-4934-8b61-75c3a92b6df1", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm import tqdm\n", + "from datasets import load_dataset\n", + "\n", + "dataset_name = \"openai_humaneval\"\n", + "dataset_subset_name = None\n", + "field_name = \"prompt\"\n", + "prompt_template = \"\"\"{text}\"\"\"\n", + "dataset = load_dataset(dataset_name, dataset_subset_name, split=\"test\")[field_name]\n", + "samples_number = 30\n", + "with AcceptanceRateRecorder(stateless_model) as ar_recorder:\n", + " for text in tqdm(dataset[:samples_number]):\n", + " tokenized_prompt = tokenizer([prompt_template.format(text=text)], return_tensors='pt')\n", + " stateless_model.generate(\n", + " **tokenized_prompt,\n", + " max_new_tokens=128,\n", + " pad_token_id=tokenizer.eos_token_id,\n", + " assistant_model=asst_model,\n", + " )\n", + "print(f\"Acceptance rate: {ar_recorder.acceptance_rate() * 100:.2f}%\")" + ] + }, { "cell_type": "markdown", "id": "3f8aa25c-de59-4e79-9a1f-c03ec76d206a", "metadata": {}, "source": [ "## Chatbot demo\n", - "We will continue to build a chatbot demo running with Gradio using the model we just exported and quantized.\n", + "We will continue to build a chatbot demo running with Gradio using the models we just exported and quantized.\n", "The chatbot will be rather simple where the user will input a message and the model will reply to the user by generating text using the entire chat history as the input to the model.\n", + "We will also add an option to accelerate inference using speculative decoding with a draft model as we described in the previous section.\n", "\n", "A lot of models that were trained for the chatbot use case have been trained with special tokens to tell the model who is the current speaker and with a special system message. 
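For example, a chat-tuned model such as Zephyr ships a chat template that inserts those role markers and the system message for you. Zephyr is used below purely as an illustration and is not part of this demo:

```python
from transformers import AutoTokenizer

# Illustration only: print the special-token chat format of a chat-tuned model.
chat_tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is OpenVINO?"},
]
print(chat_tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```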
\n", "Phi-2 wasn't trained specifically for the chatbot use case and doesn't have any special tokens either, however, it has seen chats in the training data and therefore is suited for that use case.\n", @@ -328,7 +660,7 @@ " return input_token\n", "\n", "\n", - "def generate(history, temperature, max_new_tokens, top_p, repetition_penalty):\n", + "def generate(history, temperature, max_new_tokens, top_p, repetition_penalty, assisted):\n", " \"\"\"\n", " Generates the assistant's reponse given the chatbot history and generation parameters\n", "\n", @@ -339,6 +671,7 @@ " max_new_tokens: The maximum number of tokens we allow the model to generate as a response.\n", " top_p: parameter for control the range of tokens considered by the AI model based on their cumulative probability.\n", " repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.\n", + " assisted: boolean parameter to enable/disable assisted generation with speculative decoding.\n", " Yields:\n", " Updated history and generation status.\n", " \"\"\"\n", @@ -354,15 +687,15 @@ " inputs = prepare_history_for_model(history)\n", " input_length = inputs['input_ids'].shape[1]\n", "\n", - " prompt_char = '▌'\n", + " prompt_char = \"▌\"\n", " history[-1][1] = prompt_char\n", - " yield (history, \"Status: Generating...\")\n", + " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", " \n", " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", "\n", " # Create a stopping criteria to prevent the model from playing the role of the user aswell.\n", - " stop_str = f'\\nUser:'\n", - " stopping_criteria = StoppingCriteriaList([SuffixCriteria(input_length, [stop_str], tokenizer)])\n", + " stop_str = [\"\\nUser:\", \"\\nAssistant:\", \"\\nRules:\", \"\\nQuestion:\"]\n", + " stopping_criteria = StoppingCriteriaList([SuffixCriteria(input_length, stop_str, tokenizer)])\n", " # Prepare input for generate\n", " generation_config = GenerationConfig(\n", " max_new_tokens=max_new_tokens,\n", @@ -379,7 +712,13 @@ " stopping_criteria=stopping_criteria,\n", " ) | inputs\n", "\n", - " t1 = Thread(target=model.generate, kwargs=generate_kwargs)\n", + " if assisted:\n", + " target_generate = stateless_model.generate\n", + " generate_kwargs[\"assistant_model\"] = asst_model\n", + " else:\n", + " target_generate = model.generate\n", + "\n", + " t1 = Thread(target=target_generate, kwargs=generate_kwargs)\n", " t1.start()\n", "\n", " # Initialize an empty string to store the generated text.\n", @@ -387,17 +726,18 @@ " for new_text in streamer:\n", " partial_text += new_text\n", " history[-1][1] = partial_text + prompt_char\n", - " # We don't yield the generated text until we are sure it is not the stop string\n", - " pos = partial_text.rfind(stop_str)\n", + " for s in stop_str:\n", + " if (pos := partial_text.rfind(s)) != -1:\n", + " break\n", " if pos != -1:\n", " partial_text = partial_text[:pos]\n", " break\n", - " elif is_partial_stop(partial_text, stop_str):\n", + " elif any([is_partial_stop(partial_text, s) for s in stop_str]):\n", " continue\n", - " yield (history, \"Status: Generating...\")\n", + " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", " history[-1][1] = partial_text\n", " generation_time = time.perf_counter() - start\n", - " yield (history, f'Generation time: {generation_time:.2f} sec')" + " yield history, f'Generation time: {generation_time:.2f} sec', *([gr.update(interactive=True)] * 4)" ] }, { @@ 
-430,6 +770,11 @@ "source": [ "import gradio as gr\n", "\n", + "try:\n", + " demo.close()\n", + "except:\n", + " pass\n", + "\n", "\n", "EXAMPLES = [\n", " [\"What is OpenVINO?\"],\n", @@ -455,14 +800,29 @@ " return ('', history)\n", "\n", "\n", + "def prepare_for_regenerate(history):\n", + " \"\"\"\n", + " Delete last assistant message to prepare for regeneration\n", + "\n", + " Params:\n", + " history: conversation history\n", + " Returns:\n", + " updated history\n", + " \"\"\" \n", + " history[-1][1] = None\n", + " return history\n", + "\n", + "\n", "with gr.Blocks(theme=gr.themes.Soft()) as demo:\n", " gr.Markdown('
Chat with Phi-2 on Meteor Lake iGPU
')\n", " chatbot = gr.Chatbot()\n", " with gr.Row():\n", + " assisted = gr.Checkbox(value=False, label=\"Assisted Generation\", scale=10)\n", " msg = gr.Textbox(placeholder=\"Enter message here...\", show_label=False, autofocus=True, scale=75)\n", - " status = gr.Textbox(\"Status: Idle\", show_label=False, max_lines=1, scale=25)\n", + " status = gr.Textbox(\"Status: Idle\", show_label=False, max_lines=1, scale=15)\n", " with gr.Row():\n", " submit = gr.Button(\"Submit\", variant='primary')\n", + " regenerate = gr.Button(\"Regenerate\")\n", " clear = gr.Button(\"Clear\")\n", " with gr.Accordion(\"Advanced Options:\", open=False):\n", " with gr.Row():\n", @@ -513,12 +873,24 @@ " queue=False,\n", " ).then(\n", " fn=generate,\n", - " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty],\n", - " outputs=[chatbot, status],\n", + " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", + " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", + " concurrency_limit=1,\n", + " queue=True\n", + " )\n", + " regenerate.click(\n", + " fn=prepare_for_regenerate,\n", + " inputs=chatbot,\n", + " outputs=chatbot,\n", + " queue=True,\n", + " concurrency_limit=1\n", + " ).then(\n", + " fn=generate,\n", + " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", + " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", " concurrency_limit=1,\n", " queue=True\n", " )\n", - " \n", " clear.click(fn=lambda: (None, \"Status: Idle\"), inputs=None, outputs=[chatbot, status], queue=False)" ] }, @@ -575,7 +947,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.10.13" } }, "nbformat": 4,
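A side note on the UI changes above: while a response is being generated, the event handlers now yield `gr.update(interactive=False)` for the textbox and the three buttons, and re-enable them on the final yield. A minimal standalone sketch of that pattern (assuming Gradio 4.x; the component names and the fake streaming function here are made up and unrelated to the notebook):

```python
import time
import gradio as gr


def fake_generate(history):
    # Append a new user/assistant pair and stream the reply character by character,
    # keeping the buttons disabled until generation finishes.
    history = (history or []) + [["Hello", ""]]
    for ch in "Streaming a reply...":
        history[-1][1] += ch
        time.sleep(0.05)
        yield history, gr.update(interactive=False), gr.update(interactive=False)
    # Final yield: re-enable the controls.
    yield history, gr.update(interactive=True), gr.update(interactive=True)


with gr.Blocks() as sketch:
    chatbot = gr.Chatbot()
    submit = gr.Button("Submit", variant="primary")
    clear = gr.Button("Clear")
    submit.click(fake_generate, inputs=chatbot, outputs=[chatbot, submit, clear], concurrency_limit=1)
    clear.click(lambda: None, inputs=None, outputs=chatbot, queue=False)

# sketch.launch()
```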