Add a notebook with BakLLaVA inference from Python #6

Open · wants to merge 1 commit into main

notebooks/bakllava-inference-from-python.ipynb (new file)

{
"cells": [
{
"cell_type": "markdown",
"id": "435cb4a0-ab3e-4e73-9817-feb946f112f1",
"metadata": {},
"source": [
"### Configure:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c9c3ce6a-1e62-4507-86a8-233f77f525d4",
"metadata": {},
"outputs": [],
"source": [
"low_gpu_memory_optimization = True"
]
},
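{
"cell_type": "markdown",
"id": "2f8d1a3e-6b4c-4d2a-9e1f-5a7b8c9d0e1f",
"metadata": {},
"source": [
"A quick way to pick this flag (a sketch, not part of the original flow): check how much VRAM the GPU has. The ~16 GiB threshold below is an assumption; 8-bit loading trades a little quality for a much smaller footprint. This assumes `torch` is already installed, as it is on most GPU notebook images."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a9e2b4f-7c5d-4e3b-8f20-6b8c9d0e1f2a",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: report available VRAM to guide the flag above.\n",
"import torch\n",
"\n",
"if torch.cuda.is_available():\n",
"    total_gib = torch.cuda.get_device_properties(0).total_memory / 1024**3\n",
"    print(f\"GPU VRAM: {total_gib:.1f} GiB\")\n",
"    if total_gib < 16:  # assumed threshold for comfortably holding fp16 weights\n",
"        print(\"Consider low_gpu_memory_optimization = True\")"
]
},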
{
"cell_type": "markdown",
"id": "c3cf0eaf-bac6-4f3b-a561-88f50bf9a6d8",
"metadata": {},
"source": [
"### Install dependencies:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65d3fb1d-be7e-4828-a583-53e5ab56a1d5",
"metadata": {},
"outputs": [],
"source": [
"!git clone https://github.com/SkunkworksAI/BakLLaVA.git\n",
"%cd BakLLaVA\n",
"!pip install -e .\n",
"!pip uninstall transformers -y\n",
"!pip install transformers==4.34.0"
]
},
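{
"cell_type": "markdown",
"id": "4b0f3c5a-8d6e-4f4c-9031-7c9d0e1f2a3b",
"metadata": {},
"source": [
"The `transformers==4.34.0` pin matters: the BakLLaVA code targets that API. A quick check (a sketch) that the pin took effect; restart the kernel first if pip replaced a package that was already imported."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c1a4d6b-9e7f-4a5d-a142-8d0e1f2a3b4c",
"metadata": {},
"outputs": [],
"source": [
"# Confirm the kernel sees the pinned version.\n",
"import transformers\n",
"print(transformers.__version__)  # expect 4.34.0"
]
},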
{
"cell_type": "markdown",
"id": "f1670688-a179-456c-936b-a6957d9c2b46",
"metadata": {},
"source": [
"### Run:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f87075a7-c43b-4e63-8d9c-a999eb631d0d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[2023-10-20 03:29:22,155] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
}
],
"source": [
"from transformers import AutoConfig, AutoTokenizer\n",
"from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM\n",
"from huggingface_hub import notebook_login\n",
"from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria\n",
"\n",
"from PIL import Image\n",
"import requests\n",
"from io import BytesIO\n",
"\n",
"from llava.conversation import conv_templates, SeparatorStyle\n",
"from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "708e45d8-8c4a-4888-bb25-4ebe9a0af993",
"metadata": {},
"outputs": [],
"source": [
"notebook_login()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8393e364-3cfc-4ba8-bd11-0bd183b2a835",
"metadata": {},
"outputs": [],
"source": [
"model_path = \"SkunkworksAI/BakLLaVA-1\"\n",
"\n",
"cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"if low_gpu_memory_optimization:\n",
" model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, load_in_8bit=True, config=cfg_pretrained)\n",
"else:\n",
" model = LlavaMistralForCausalLM.from_pretrained(model_path, config=cfg_pretrained)\n",
" model.to(\"cuda\")"
]
},
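{
"cell_type": "markdown",
"id": "6d2b5e7c-0f80-4b6e-b253-9e1f2a3b4c5d",
"metadata": {},
"source": [
"Optional sanity check (a sketch): confirm where the weights landed and their dtype. With `load_in_8bit=True` most linear layers are stored as int8 on the GPU (embeddings may stay in fp16); otherwise expect full-precision weights on `cuda`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7e3c6f8d-1a91-4c7f-c364-0f2a3b4c5d6e",
"metadata": {},
"outputs": [],
"source": [
"# Inspect one parameter as a proxy for the whole model, plus the size.\n",
"first_param = next(model.parameters())\n",
"print(first_param.device, first_param.dtype)\n",
"print(f\"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters\")"
]
},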
{
"cell_type": "code",
"execution_count": 5,
"id": "2ab75125-9a4b-4df1-b314-4b4e96ad8c94",
"metadata": {},
"outputs": [],
"source": [
"mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
"mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
"if mm_use_im_patch_token:\n",
" tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
"if mm_use_im_start_end:\n",
" tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
"model.resize_token_embeddings(len(tokenizer))\n",
"\n",
"vision_tower = model.get_vision_tower()\n",
"if not vision_tower.is_loaded:\n",
" vision_tower.load_model()\n",
"vision_tower.to(device='cuda', dtype=torch.float16)\n",
"image_processor = vision_tower.image_processor\n",
"\n",
"if hasattr(model.config, \"max_sequence_length\"):\n",
" context_len = model.config.max_sequence_length\n",
"else:\n",
" context_len = 2048"
]
},
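{
"cell_type": "markdown",
"id": "8f4d7a9e-2ba2-4d80-9475-1a3b4c5d6e7f",
"metadata": {},
"source": [
"A small check (a sketch): the CLIP image processor exposes the crop size it resizes to, which helps when debugging preprocessing. `336x336` is an assumption about BakLLaVA's default vision tower."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9a5e8baf-3cb3-4e91-a586-2b4c5d6e7f80",
"metadata": {},
"outputs": [],
"source": [
"# Expect something like {'height': 336, 'width': 336}.\n",
"print(type(image_processor).__name__, image_processor.crop_size)"
]
},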
{
"cell_type": "code",
"execution_count": 6,
"id": "64ecf511-ac4f-4bea-b0dd-d8f5611a9d58",
"metadata": {},
"outputs": [],
"source": [
"def load_image(image_file):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(image_file).convert('RGB')\n",
" return image\n",
"\n",
"\n",
"# image = load_image(\"https://t4.ftcdn.net/jpg/00/97/58/97/360_F_97589769_t45CqXyzjz0KXwoBZT9PRaWGHRk5hQqQ.jpg\")\n",
"image = load_image(\"https://cdn.discordapp.com/attachments/1096822099345145969/1164641565550067852/heart_1.png?ex=6543f3fb&is=65317efb&hm=448cb26e19c141871e776af98077c4c1e97a8f29b96916ab671e5010c00e3625&\")\n",
"\n",
"if low_gpu_memory_optimization:\n",
" image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
"else:\n",
" image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()"
]
},
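{
"cell_type": "markdown",
"id": "ab6f9cb0-4dc4-4fa2-b697-3c5d6e7f8091",
"metadata": {},
"source": [
"Optional check (a sketch): the processor should emit a single `1 x 3 x H x W` pixel tensor, in fp16 on the low-memory path."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc70adc1-5ed5-40b3-c7a8-4d6e7f8091a2",
"metadata": {},
"outputs": [],
"source": [
"# Expect roughly torch.Size([1, 3, 336, 336]) on cuda.\n",
"print(image_tensor.shape, image_tensor.dtype, image_tensor.device)"
]
},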
{
"cell_type": "code",
"execution_count": 7,
"id": "dd3526fa-edb4-41e4-9494-f43f5fec2c45",
"metadata": {},
"outputs": [],
"source": [
"query = \"Describe this image\"\n",
"\n",
"if model.config.mm_use_im_start_end:\n",
" query = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + query\n",
"else:\n",
" query = DEFAULT_IMAGE_TOKEN + '\\n' + query\n",
"\n",
"conv = conv_templates[\"llava_v1\"].copy()\n",
"\n",
"conv.append_message(conv.roles[0], query)\n",
"conv.append_message(conv.roles[1], None)\n",
"prompt = conv.get_prompt()\n",
"input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
"\n",
"stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
"keywords = [stop_str]\n",
"stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)"
]
},
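{
"cell_type": "markdown",
"id": "cd81bed2-6fe6-41c4-d8b9-5e7f8091a2b3",
"metadata": {},
"source": [
"It helps to eyeball the assembled prompt before generating: `tokenizer_image_token` replaces the image placeholder with `IMAGE_TOKEN_INDEX`, so the token count will not match the visible text."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de92cfe3-70f7-42d5-e9ca-6f8091a2b3c4",
"metadata": {},
"outputs": [],
"source": [
"# Inspect the conversation-formatted prompt and the resulting id tensor.\n",
"print(prompt)\n",
"print(input_ids.shape)"
]
},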
{
"cell_type": "code",
"execution_count": 8,
"id": "1383e630-8651-4eb7-898a-7312094f7b9e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The image features a detailed illustration of a human heart, showcasing its various parts and blood vessels. The heart is depicted in full color, with its interior and exterior structures clearly visible.\n",
"\n",
"The heart is surrounded by a network of blood vessels, including arteries and veins. There are at least six visible arteries, some of which are located near the top and bottom of the heart, while others are situated on its right side. Additionally, there are five visible veins, with some located near the top and bottom of the heart, and others on the left side. The arrangement of these blood vessels highlights the complex circulatory system that sustains the heart itself.\n"
]
}
],
"source": [
"with torch.inference_mode():\n",
" output_ids = model.generate(\n",
" input_ids=input_ids,\n",
" images=image_tensor,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=1024,\n",
" use_cache=True,\n",
" stopping_criteria=[stopping_criteria])\n",
"\n",
"input_token_len = input_ids.shape[1]\n",
"n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
"if n_diff_input_output > 0:\n",
" print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
"outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
"outputs = outputs.strip()\n",
"if outputs.endswith(stop_str):\n",
" outputs = outputs[:-len(stop_str)]\n",
"outputs = outputs.strip()\n",
"print(outputs)"
]
},
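{
"cell_type": "markdown",
"id": "efa3d0f4-8108-43e6-fadb-708091a2b3c4",
"metadata": {},
"source": [
"Putting it together: a reusable helper, sketched under the assumptions above. The name `describe` is hypothetical; it simply bundles the image loading, prompt construction, and generation steps so follow-up questions are a single call."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0b4e105-9219-44f7-8bec-8191a2b3c4d5",
"metadata": {},
"outputs": [],
"source": [
"def describe(image_file, question=\"Describe this image\"):\n",
"    # Hypothetical convenience wrapper; reuses the globals defined above\n",
"    # (model, tokenizer, image_processor, load_image, mm_use_im_start_end).\n",
"    image = load_image(image_file)\n",
"    pixels = image_processor.preprocess(image, return_tensors='pt')['pixel_values']\n",
"    pixels = pixels.half().cuda() if low_gpu_memory_optimization else pixels.cuda()\n",
"\n",
"    if mm_use_im_start_end:\n",
"        q = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + question\n",
"    else:\n",
"        q = DEFAULT_IMAGE_TOKEN + '\\n' + question\n",
"    conv = conv_templates[\"llava_v1\"].copy()\n",
"    conv.append_message(conv.roles[0], q)\n",
"    conv.append_message(conv.roles[1], None)\n",
"    ids = tokenizer_image_token(conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
"\n",
"    stop = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
"    stopping = KeywordsStoppingCriteria([stop], tokenizer, ids)\n",
"    with torch.inference_mode():\n",
"        out = model.generate(input_ids=ids, images=pixels, do_sample=True, temperature=0.2,\n",
"                             max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping])\n",
"    text = tokenizer.batch_decode(out[:, ids.shape[1]:], skip_special_tokens=True)[0].strip()\n",
"    return text[:-len(stop)].strip() if text.endswith(stop) else text\n",
"\n",
"# Usage (hypothetical URL):\n",
"# print(describe(\"https://example.com/cat.jpg\", \"What animal is this?\"))"
]
}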
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}