-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvlm_processor.py
110 lines (99 loc) · 3.81 KB
/
vlm_processor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import cv2
import torch
from PIL import Image
import io
class VLM:
    """Webcam-to-text helper built on the Qwen2-VL instruct model.

    The class loads the model and its processor once, and exposes methods to
    grab a frame from a webcam, describe a PIL image, or do both in one call.
    """

    # Hugging Face model identifier used for both the model and the processor.
    MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"

    def __init__(self, use_flash_attention: bool = False, device: str = "mps"):
        """Load the model and processor.

        Args:
            use_flash_attention: When true, load the weights in bfloat16 and
                request the ``flash_attention_2`` attention implementation.
            device: Device string the loaded model is moved to.
        """
        print("Initializing VLM")
        load_kwargs = {"torch_dtype": "auto", "device_map": "auto"}
        if use_flash_attention:
            load_kwargs.update(
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            )
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            self.MODEL_NAME, **load_kwargs
        )
        # NOTE(review): the model is loaded with device_map="auto" and then
        # moved explicitly; confirm this combination behaves as intended on
        # the target hardware.
        self.model.to(device)
        self.processor = AutoProcessor.from_pretrained(self.MODEL_NAME)
        print("VLM initialized")

    @staticmethod
    def _is_mostly_black(frame, threshold: float = 5.0) -> bool:
        """Return True when the mean intensity of *frame* is below *threshold*.

        The mean is taken over every pixel and every channel, so a frame that
        is bright in only one colour channel is not mistaken for a black one.
        *frame* must expose a ``mean()`` method (e.g. a NumPy array).
        """
        return float(frame.mean()) < threshold

    def capture_image(
        self,
        *,
        camera_index: int = 0,
        max_attempts: int = 10,
        black_threshold: float = 5.0,
    ) -> "Image.Image":
        """Capture a frame from the webcam and return it as an RGB PIL image.

        Frames that fail to read or that are almost entirely black are
        skipped, since the first reads after opening the device may be
        unusable.

        Args:
            camera_index: Index of the capture device to open.
            max_attempts: Maximum number of frames to read before giving up.
            black_threshold: Mean pixel value below which a frame is
                considered black and discarded.

        Raises:
            RuntimeError: If the webcam cannot be opened, or if no usable
                frame is obtained within *max_attempts* reads.
        """
        cap = cv2.VideoCapture(camera_index)
        try:
            if not cap.isOpened():
                raise RuntimeError("Could not access webcam")
            for _ in range(max_attempts):
                ret, frame = cap.read()
                if not ret or frame is None:
                    continue
                if self._is_mostly_black(frame, black_threshold):
                    continue
                # OpenCV delivers BGR; PIL expects RGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                return Image.fromarray(rgb_frame)
        finally:
            # Always free the device, including when a read or conversion raises.
            cap.release()
        raise RuntimeError("Failed to capture non-black image after multiple attempts")

    def get_image_description(
        self,
        image: "Image.Image",
        *,
        prompt: str = "Describe this image.",
        max_new_tokens: int = 128,
    ) -> str:
        """Generate a text description for a PIL image.

        Args:
            image: The image to describe.
            prompt: Text instruction sent to the model alongside the image.
            max_new_tokens: Upper bound on the number of generated tokens.

        Returns:
            The decoded model output for the single image/prompt pair.
        """
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Render the chat template to text, then build the model inputs.
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)
        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0]

    def get_image_and_description(self) -> str:
        """Capture a webcam image and return its description.

        This is a best-effort entry point: any exception raised while
        capturing or describing is caught and reported in the returned string
        instead of propagating.

        Returns:
            The description, or ``"Error processing image: <message>"`` on
            failure.
        """
        try:
            print("Capturing image...")
            image = self.capture_image()
            print("Getting image description...")
            desc = self.get_image_description(image)
            print(desc)
            return desc
        except Exception as e:
            return f"Error processing image: {e}"
if __name__ == "__main__":
    # Script entry point: build the processor with its default settings, run
    # one capture-and-describe cycle, and print the returned text (either the
    # description or the error message).
    print(VLM().get_image_and_description())