Hi, I am trying to use LLaVA to run inference with the PLLaVA model, but the result is really hard to debug.
The output:
USER: bfxs
/data/miniconda3/envs/env-3.9.2/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:520: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.
warnings.warn(
ASSISTANT: hidden_states: torch.Size([16, 576, 5120])
input torch.Size([16, 576, 5120]) num_videos 1 frame_shape [24, 24]
Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.
USER:
bfxs ASSISTANT:
AI: Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.
USER:
bfxs ASSISTANT:
:
USER:
It simply repeats the input text.
import argparse
import os

import torch
from safetensors import safe_open

# Model classes come from the PLLaVA repo layout; the other helpers used below
# (conv_templates, DEFAULT_IMAGE_TOKEN, KeywordsStoppingCriteria, load_image,
# disable_torch_init, get_model_name_from_path, load_model_auto, is_image)
# come from the LLaVA utilities and my own wrappers and are omitted here.
from models.pllava import PllavaConfig, PllavaForConditionalGeneration, PllavaProcessor

dtype = torch.bfloat16  # module-level dtype used by the functions below


def load_pllava(
    repo_id,
    num_frames,
    use_lora=False,
    weight_dir=None,
    lora_alpha=32,
    use_multi_gpus=False,
    pooling_shape=(16, 12, 12),
):
    kwargs = {
        "num_frames": num_frames,
    }
    # print("===============>pooling_shape", pooling_shape)
    if num_frames == 0:
        kwargs.update(
            pooling_shape=(0, 12, 12)
        )  # would produce a bug if the pooling projector is ever used
    config = PllavaConfig.from_pretrained(
        repo_id if not use_lora else weight_dir,
        pooling_shape=pooling_shape,
        **kwargs,
    )
    with torch.no_grad():
        model = PllavaForConditionalGeneration.from_pretrained(
            repo_id, config=config, torch_dtype=dtype
        )
    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception:
        processor = PllavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    # load weights
    if weight_dir is not None:
        state_dict = {}
        save_fnames = os.listdir(weight_dir)
        if "model.safetensors" in save_fnames:
            use_full = False
            for fn in save_fnames:
                if fn.startswith("model-0"):
                    use_full = True
                    break
        else:
            use_full = True
        if not use_full:
            print("Loading weight from", weight_dir, "model.safetensors")
            with safe_open(
                f"{weight_dir}/model.safetensors", framework="pt", device="cpu"
            ) as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
        else:
            print("Loading weight from", weight_dir)
            for fn in save_fnames:
                if fn.startswith("model-0"):
                    with safe_open(
                        f"{weight_dir}/{fn}", framework="pt", device="cpu"
                    ) as f:
                        for k in f.keys():
                            state_dict[k] = f.get_tensor(k)
        if "model" in state_dict.keys():
            msg = model.load_state_dict(state_dict["model"], strict=False)
        else:
            msg = model.load_state_dict(state_dict, strict=False)
        print(msg)
    model.to("cuda")
    model.eval()
    return model, processor
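For context, a minimal call to this loader looks like the following; the repo id and weight directory here are placeholders rather than the exact paths I use:

model, processor = load_pllava(
    repo_id="llava-hf/llava-1.5-7b-hf",  # placeholder base model id
    num_frames=16,
    use_lora=False,
    weight_dir="checkpoints/pllava-weights",  # placeholder directory with safetensors weights
    pooling_shape=(16, 12, 12),
)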
def pllava_answer(
    conv,
    model,
    processor,
    img_list,
    do_sample=True,
    max_new_tokens=200,
    num_beams=1,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1,
    temperature=1.0,
    stop_criteria_keywords=None,
    print_res=False,
):
    # torch.cuda.empty_cache()
    prompt = conv.get_prompt()
    inputs = processor(text=prompt, images=img_list, return_tensors="pt")
    if inputs["pixel_values"] is None:
        inputs.pop("pixel_values")
    inputs = inputs.to(model.device)
    # set up stopping criteria
    if stop_criteria_keywords is not None:
        stopping_criteria = [
            KeywordsStoppingCriteria(
                stop_criteria_keywords, processor.tokenizer, inputs["input_ids"]
            )
        ]
    else:
        stopping_criteria = None
    with torch.no_grad():
        output_token = model.generate(
            **inputs,
            media_type="video",
            # media_type="image",
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            temperature=temperature,
            stopping_criteria=stopping_criteria,
        )
    output_text = processor.batch_decode(
        output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    print(output_text)
    if print_res:  # debug usage
        # print("### PROMPTING LM WITH: ", prompt)
        print("AI: ", output_text)
    if conv.roles[-1] == "<|im_start|>assistant\n":
        split_tag = "<|im_start|> assistant\n"
    else:
        split_tag = conv.roles[-1]
    output_text = output_text.split(split_tag)[-1]
    ending = conv.sep if isinstance(conv.sep, str) else conv.sep[1]
    output_text = output_text.removesuffix(ending).strip()
    conv.messages[-1][1] = output_text
    return output_text, conv
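A single-turn call follows the same pattern as the interactive loop in main() below (conv_templates and DEFAULT_IMAGE_TOKEN come from the LLaVA helpers):

conv = conv_templates["vicuna_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + "Describe the video in detail.")
conv.append_message(conv.roles[1], None)
img_list = [image] * 16  # one PIL image repeated as a 16-frame "video"
llm_response, conv = pllava_answer(
    conv=conv,
    model=model,
    processor=processor,
    do_sample=False,
    img_list=img_list,
    max_new_tokens=256,
    print_res=True,
)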
def main(args):
    disable_torch_init()
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, processor = load_model_auto(args.model_path, dtype=dtype)
    model = model.to(dtype)
    print(f"using dtype: {dtype}")
    image = load_image(args.image_file)
    image_tensor = (
        processor(images=image, return_tensors="pt")["pixel_values"]
        .to(model.device)
        .to(dtype)
    )
    while True:
        conv = conv_templates["vicuna_v1"].copy()
        conv.system = "Carefully watch the video and pay attention to the cause and sequence of events, the detail and movement of objects, and the action and pose of persons. Based on your observations, select the best option that accurately addresses the question.\n"
        try:
            inp = input(f"{conv.roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break
        if is_image(inp):
            image = load_image(inp)
            image_tensor = (
                processor(images=image, return_tensors="pt")["pixel_values"]
                .to(model.device)
                .to(dtype)
            )
            # print('updated new image')
            # clear conv history
            conv.messages = []
            print("Updated image, start new chat session.")
            continue
        print(f"{conv.roles[1]}: ", end="")
        # conv.user_query("Describe the video in details.", is_mm=True)
        if image is not None:
            # first message
            inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
            conv.append_message(conv.roles[0], inp)
            # image = None
        else:
            # later messages
            conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        # prompt = conv.get_prompt()
        img_list = [image] * 16
        llm_response, conv = pllava_answer(
            conv=conv,
            model=model,
            processor=processor,
            do_sample=False,
            img_list=img_list,
            max_new_tokens=256,
            print_res=True,
        )
        print(llm_response)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", type=str, default="checkpoints/llava-qwen-4b-finetune/"
    )
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, default="images/kobe.jpg")
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    main(args)
Above is the simplest inference code, borrowed from PLLaVA, but the result is always wrong.
I repeat a single image 16 times and feed it in as a single video (sketched below).
Any help?
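For reference, the video input is built like this, which just restates what main() does:

prompt = conv.get_prompt()  # as in pllava_answer
frames = [load_image("images/kobe.jpg")] * 16  # one image repeated as a 16-frame "video"
inputs = processor(text=prompt, images=frames, return_tensors="pt")
# the 16 frames presumably correspond to the leading 16 in the
# "hidden_states: torch.Size([16, 576, 5120])" line of the log above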