From 3e0d0751741001ff8f17d0a3b8c4a1033ce5c09f Mon Sep 17 00:00:00 2001 From: kopyl Date: Fri, 20 Oct 2023 07:04:02 +0300 Subject: [PATCH] Add a notebook with BakLLaVA inference from Python --- .../bakllava-inference-from-python.ipynb | 256 ++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 notebooks/bakllava-inference-from-python.ipynb diff --git a/notebooks/bakllava-inference-from-python.ipynb b/notebooks/bakllava-inference-from-python.ipynb new file mode 100644 index 0000000..f449a51 --- /dev/null +++ b/notebooks/bakllava-inference-from-python.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "435cb4a0-ab3e-4e73-9817-feb946f112f1", + "metadata": {}, + "source": [ + "### Configure:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c9c3ce6a-1e62-4507-86a8-233f77f525d4", + "metadata": {}, + "outputs": [], + "source": [ + "low_gpu_memory_optimization = True" + ] + }, + { + "cell_type": "markdown", + "id": "c3cf0eaf-bac6-4f3b-a561-88f50bf9a6d8", + "metadata": {}, + "source": [ + "### Install dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65d3fb1d-be7e-4828-a583-53e5ab56a1d5", + "metadata": {}, + "outputs": [], + "source": [ + "!git clone https://github.com/SkunkworksAI/BakLLaVA.git\n", + "%cd BakLLaVA\n", + "!pip install -e .\n", + "!pip uninstall transformers -y\n", + "!pip install transformers==4.34.0" + ] + }, + { + "cell_type": "markdown", + "id": "f1670688-a179-456c-936b-a6957d9c2b46", + "metadata": {}, + "source": [ + "### Run:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f87075a7-c43b-4e63-8d9c-a999eb631d0d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2023-10-20 03:29:22,155] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n" + ] + } + ], + "source": [ + "from transformers import AutoConfig, AutoTokenizer\n", + "from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM\n", + "from huggingface_hub import notebook_login\n", + "from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria\n", + "\n", + "from PIL import Image\n", + "import requests\n", + "from io import BytesIO\n", + "\n", + "from llava.conversation import conv_templates, SeparatorStyle\n", + "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n", + "\n", + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "708e45d8-8c4a-4888-bb25-4ebe9a0af993", + "metadata": {}, + "outputs": [], + "source": [ + "notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8393e364-3cfc-4ba8-bd11-0bd183b2a835", + "metadata": {}, + "outputs": [], + "source": [ + "model_path = \"SkunkworksAI/BakLLaVA-1\"\n", + "\n", + "cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n", + "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n", + "if low_gpu_memory_optimization:\n", + " model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, load_in_8bit=True, config=cfg_pretrained)\n", + "else:\n", + " model = LlavaMistralForCausalLM.from_pretrained(model_path, config=cfg_pretrained)\n", + " model.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2ab75125-9a4b-4df1-b314-4b4e96ad8c94", + "metadata": {}, + "outputs": [], + "source": [ + "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n", + "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n", + "if mm_use_im_patch_token:\n", + " tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n", + "if mm_use_im_start_end:\n", + " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n", + "model.resize_token_embeddings(len(tokenizer))\n", + "\n", + "vision_tower = model.get_vision_tower()\n", + "if not vision_tower.is_loaded:\n", + " vision_tower.load_model()\n", + "vision_tower.to(device='cuda', dtype=torch.float16)\n", + "image_processor = vision_tower.image_processor\n", + "\n", + "if hasattr(model.config, \"max_sequence_length\"):\n", + " context_len = model.config.max_sequence_length\n", + "else:\n", + " context_len = 2048" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "64ecf511-ac4f-4bea-b0dd-d8f5611a9d58", + "metadata": {}, + "outputs": [], + "source": [ + "def load_image(image_file):\n", + " if image_file.startswith('http') or image_file.startswith('https'):\n", + " response = requests.get(image_file)\n", + " image = Image.open(BytesIO(response.content)).convert('RGB')\n", + " else:\n", + " image = Image.open(image_file).convert('RGB')\n", + " return image\n", + "\n", + "\n", + "# image = load_image(\"https://t4.ftcdn.net/jpg/00/97/58/97/360_F_97589769_t45CqXyzjz0KXwoBZT9PRaWGHRk5hQqQ.jpg\")\n", + "image = load_image(\"https://cdn.discordapp.com/attachments/1096822099345145969/1164641565550067852/heart_1.png?ex=6543f3fb&is=65317efb&hm=448cb26e19c141871e776af98077c4c1e97a8f29b96916ab671e5010c00e3625&\")\n", + "\n", + "if low_gpu_memory_optimization:\n", + " image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n", + "else:\n", + " image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "dd3526fa-edb4-41e4-9494-f43f5fec2c45", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"Describe this image\"\n", + "\n", + "if model.config.mm_use_im_start_end:\n", + " query = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + query\n", + "else:\n", + " query = DEFAULT_IMAGE_TOKEN + '\\n' + query\n", + "\n", + "conv = conv_templates[\"llava_v1\"].copy()\n", + "\n", + "conv.append_message(conv.roles[0], query)\n", + "conv.append_message(conv.roles[1], None)\n", + "prompt = conv.get_prompt()\n", + "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n", + "\n", + "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n", + "keywords = [stop_str]\n", + "stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1383e630-8651-4eb7-898a-7312094f7b9e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n", + "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The image features a detailed illustration of a human heart, showcasing its various parts and blood vessels. The heart is depicted in full color, with its interior and exterior structures clearly visible.\n", + "\n", + "The heart is surrounded by a network of blood vessels, including arteries and veins. There are at least six visible arteries, some of which are located near the top and bottom of the heart, while others are situated on its right side. Additionally, there are five visible veins, with some located near the top and bottom of the heart, and others on the left side. The arrangement of these blood vessels highlights the complex circulatory system that sustains the heart itself.\n" + ] + } + ], + "source": [ + "with torch.inference_mode():\n", + " output_ids = model.generate(\n", + " input_ids=input_ids,\n", + " images=image_tensor,\n", + " do_sample=True,\n", + " temperature=0.2,\n", + " max_new_tokens=1024,\n", + " use_cache=True,\n", + " stopping_criteria=[stopping_criteria])\n", + "\n", + "input_token_len = input_ids.shape[1]\n", + "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n", + "if n_diff_input_output > 0:\n", + " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n", + "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n", + "outputs = outputs.strip()\n", + "if outputs.endswith(stop_str):\n", + " outputs = outputs[:-len(stop_str)]\n", + "outputs = outputs.strip()\n", + "print(outputs)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}