
Commit 2287882

Add init example for omni mode
1 parent 78cca0a commit 2287882

File tree

1 file changed: +135 −0 lines
  • python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import math
import torch
import librosa
import argparse
import numpy as np
from PIL import Image
from moviepy import VideoFileClip
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel

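# Note (an assumption worth flagging): `from moviepy import VideoFileClip`
# targets moviepy >= 2.0; in the 1.x series this import lived at
# `moviepy.editor` instead.
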
# The video chunk function is adapted from https://huggingface.co/openbmb/MiniCPM-o-2_6#chat-inference
def get_video_chunk_content(video_path, temp_audio_name, flatten=True):
    video = VideoFileClip(video_path)
    print('video_duration:', video.duration)

    # Extract the audio track to a temporary 16 kHz mono WAV file
    with open(temp_audio_name, 'wb') as temp_audio_file:
        temp_audio_file_path = temp_audio_file.name
        video.audio.write_audiofile(temp_audio_file_path, codec="pcm_s16le", fps=16000)
        audio_np, sr = librosa.load(temp_audio_file_path, sr=16000, mono=True)
    num_units = math.ceil(video.duration)

    # 1 frame + 1s audio chunk per unit
    contents = []
    for i in range(num_units):
        frame = video.get_frame(i + 1)
        image = Image.fromarray(frame.astype(np.uint8))
        audio = audio_np[sr * i:sr * (i + 1)]
        if flatten:
            contents.extend(["<unit>", image, audio])
        else:
            contents.append(["<unit>", image, audio])

    return contents

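# Illustrative shape of the return value (names image_i / audio_i are
# hypothetical): for a ~3 s clip with flatten=True, the list interleaves a
# marker, one frame and one 1 s audio slice per second, e.g.
#   ["<unit>", image_0, audio_0, "<unit>", image_1, audio_1, "<unit>", image_2, audio_2]
# With flatten=False, each ["<unit>", image, audio] triple is kept as its own
# nested sub-list instead.
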
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Chat with MiniCPM-o-2_6 in Omni mode')
    parser.add_argument('--repo-id-or-model-path', type=str, default='openbmb/MiniCPM-o-2_6',
                        help='The Hugging Face or ModelScope repo id for the MiniCPM-o-2_6 model to be downloaded'
                             ', or the path to the checkpoint folder')
    parser.add_argument('--video-path', type=str, required=True,
                        help='The path to the video, which the model uses to conduct inference '
                             'based on its images and audio.')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()

    model_path = args.repo_id_or_model_path
    video_path = args.video_path

    # Load the model in 4-bit precision, which converts the relevant layers
    # in the model into INT4 format
    model = AutoModel.from_pretrained(model_path,
                                      load_in_low_bit="sym_int4",
                                      optimize_model=True,
                                      trust_remote_code=True,
                                      attn_implementation='sdpa',
                                      use_cache=True,
                                      init_vision=True,
                                      init_audio=True,
                                      init_tts=False,  # TTS is not initialized; audio output is disabled below
                                      modules_to_not_convert=["vpm", "resampler"])  # keep the vision encoder and resampler unquantized

    # Cast the remaining parameters to FP16 and move the model to the Intel GPU
    model = model.half().to('xpu')

    tokenizer = AutoTokenizer.from_pretrained(model_path,
                                              trust_remote_code=True)

    # The following code for generation is adapted from https://huggingface.co/openbmb/MiniCPM-o-2_6#chat-inference
    temp_audio_name = "temp_audio.wav"
    contents = get_video_chunk_content(video_path, temp_audio_name)
    messages = [{"role": "user", "content": contents}]

    # The frames and audio slices are already in memory, so the temporary
    # WAV file can be removed before inference
    if os.path.exists(temp_audio_name):
        os.remove(temp_audio_name)

    with torch.inference_mode():
        # The ipex_llm model needs a warmup run before the inference time
        # can be measured accurately
        model.chat(
            msgs=messages,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.5,
            max_new_tokens=args.n_predict,
            omni_input=True,  # omni_input=True must be set for omni inference
            use_tts_template=False,
            generate_audio=False,
            max_slice_nums=1,
            use_image_id=False,
        )

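        # A note on the timing below (illustrative, not part of the original
        # file): XPU kernels are launched asynchronously, so the
        # torch.xpu.synchronize() call after generation ensures all queued
        # device work has finished before the end timestamp is read.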
        st = time.time()
        response = model.chat(
            msgs=messages,
            tokenizer=tokenizer,
            sampling=True,
            temperature=0.5,
            max_new_tokens=args.n_predict,
            omni_input=True,  # omni_input=True must be set for omni inference
            use_tts_template=False,
            generate_audio=False,
            max_slice_nums=1,
            use_image_id=False,
        )
        torch.xpu.synchronize()
        end = time.time()

    print(f'Inference time: {end-st} s')
    print('-'*20, 'Input Video Path', '-'*20)
    print(video_path)
    print('-'*20, 'Chat Output', '-'*20)
    print(response)
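
A minimal invocation sketch, assuming the file above is saved as omni.py (the script name and sample video path are illustrative, not part of the commit):

    python omni.py --video-path ./sample_video.mp4 --n-predict 64

If --repo-id-or-model-path is omitted, the script falls back to the openbmb/MiniCPM-o-2_6 repo id; pass a local checkpoint folder to skip the download.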
