captioning + dataset preparation + inference + improvements #34

Open

wants to merge 15 commits into main
4 changes: 2 additions & 2 deletions Makefile
@@ -1,10 +1,10 @@
 .PHONY: quality style

-check_dirs := training tests
+check_dirs := training tests video_recaptioning

 quality:
 	ruff check $(check_dirs)
-	ruff format --check $(check_dirs) setup.py
+	ruff format --check $(check_dirs)

 style:
 	ruff check $(check_dirs) --fix
56 changes: 56 additions & 0 deletions video_recaptioning/dataset.py
@@ -0,0 +1,56 @@
import base64
import io
import os
from typing import Dict, List, Tuple, Union

import decord
from PIL import Image
from torch.utils.data import Dataset


class VideoDataset(Dataset):
    def __init__(self, root_video_dir: str, max_num_frames: int, video_extensions: Tuple[str, ...] = (".mp4",)):
        self.root_video_dir = root_video_dir
        self.max_num_frames = max_num_frames

        video_files = [
            os.path.join(root_video_dir, f) for f in os.listdir(root_video_dir) if f.endswith(video_extensions)
        ]
        self.video_files = sorted(video_files)

    def __len__(self) -> int:
        return len(self.video_files)

    def __getitem__(self, index: int) -> Dict[str, Union[List[str], str]]:
        video_path = self.video_files[index]
        return self.load_video(video_path)

    def load_video(self, path: str) -> Dict[str, Union[List[str], str]]:
        video_reader = decord.VideoReader(uri=path)
        base_name = os.path.basename(path).split(".")[0]

        # Decode at most `max_num_frames` frames and convert them to PIL images.
        video_frames = [
            Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader))
        ][: self.max_num_frames]
        return {"video": [self.encode_image(frame) for frame in video_frames], "video_name": base_name}

    def encode_image(self, image: Image.Image) -> str:
        # Serialize the frame to a base64-encoded JPEG so it can be embedded in a `data:` URL.
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        image_bytes = buffered.getvalue()
        return base64.b64encode(image_bytes).decode("utf-8")

if __name__ == "__main__":
    import tempfile

    from huggingface_hub import snapshot_download

    with tempfile.TemporaryDirectory() as tmpdirname:
        video_root_dir = snapshot_download(
            repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset", local_dir=tmpdirname
        )

        dataset = VideoDataset(os.path.join(video_root_dir, "videos"), max_num_frames=16)
        print(len(dataset))

        for item in dataset:
            print(len(item["video"]))
            break
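
For reference, the base64 payloads produced by `encode_image` round-trip back to PIL images. A minimal sketch, assuming frames were encoded as above (the `decode_image` helper is ours, not part of this PR):

import base64
import io

from PIL import Image


def decode_image(encoded: str) -> Image.Image:
    # Inverse of VideoDataset.encode_image: base64 string -> JPEG bytes -> PIL image.
    return Image.open(io.BytesIO(base64.b64decode(encoded)))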
10 changes: 10 additions & 0 deletions video_recaptioning/launch.sh
@@ -0,0 +1,10 @@
#!/bin/bash

export CUDA_VISIBLE_DEVICES=1,2

python recaption.py --root_dir="video-dataset-disney/videos" \
--output_dir="video-dataset-disney" \
--max_num_frames=8 --max_tokens=120 \
--num_data_workers=4 --batch_size=2 \
--prompt="Describe this set of frames. Consider the frames to be a part of the same video." \
--num_artifact_workers=4
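
With these settings each video contributes at most max_num_frames = 8 frames, and `main` in recaption.py passes batch_size * max_num_frames = 2 * 8 = 16 to the engine's `limit_mm_per_prompt`, a conservative upper bound on the number of images per request.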
110 changes: 110 additions & 0 deletions video_recaptioning/recaption.py
@@ -0,0 +1,110 @@
"""
Needs `vllm` to be installed from the `main`.
"""

from vllm import LLM, SamplingParams
import queue
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from concurrent.futures import ThreadPoolExecutor
import torch
import fire
import os

from dataset import VideoDataset

def save_results(output_queue: queue.Queue, output_dir: str):
    # Consumer loop: drain the queue and write one caption file per video
    # until the `None` sentinel signals that generation is finished.
    while True:
        try:
            item = output_queue.get(timeout=5)
            if item is None:
                break

            video_names, outputs = item
            outputs = [o.outputs[0].text for o in outputs]

            for i, pred_caption in enumerate(outputs):
                with open(os.path.join(output_dir, f"{video_names[i]}_caption.txt"), "w") as f:
                    f.write(pred_caption)

        except queue.Empty:
            continue

def create_messages(batch, prompt: str):
    # Build one single-turn conversation per video: the text prompt followed by
    # that video's frames as base64-encoded `data:` URLs. Each conversation is
    # wrapped in its own list so that `LLM.chat` treats the batch as independent
    # requests rather than as turns of a single conversation.
    messages = []
    for video in batch["videos"]:
        content = [{"type": "text", "text": prompt}]
        for frame in video:
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}})
        messages.append([{"role": "user", "content": content}])
    return messages

def collate_fn(batch):
    inputs = {
        "videos": [sample["video"] for sample in batch],
        "video_names": [sample["video_name"] for sample in batch],
    }
    return inputs


def prepare_dataloader(video_root_dir, max_num_frames, num_data_workers, batch_size):
    dataset = VideoDataset(video_root_dir, max_num_frames=max_num_frames)

    # Default to a single process; pick up the rank and world size when running
    # under `torch.distributed`.
    rank = 0
    world_size = 1
    if torch.distributed.is_initialized():
        group = torch.distributed.group.WORLD
        rank = torch.distributed.get_rank(group=group)
        world_size = torch.distributed.get_world_size(group=group)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=num_data_workers,
        pin_memory=True,
    )
    return dataloader

def load_model(max_num_frames: int, max_tokens: int):
    # `limit_mm_per_prompt` caps how many images a single request may carry.
    vllm_engine = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": max_num_frames})
    sampling_params = SamplingParams(max_tokens=max_tokens)
    return vllm_engine, sampling_params


def main(
    root_dir, prompt, output_dir, max_num_frames, max_tokens, num_data_workers=4, batch_size=8, num_artifact_workers=4
):
    # Conservative cap: enough for every frame of every video in a batch, even
    # though each conversation only carries `max_num_frames` images.
    max_allowed_imgs_per_req = batch_size * max_num_frames
    vllm_engine, sampling_params = load_model(
        max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens
    )
    dataloader = prepare_dataloader(
        video_root_dir=root_dir, max_num_frames=max_num_frames,
        num_data_workers=num_data_workers, batch_size=batch_size
    )

    # The writer runs on separate threads so saving captions never blocks inference.
    output_queue = queue.Queue()
    save_thread = ThreadPoolExecutor(max_workers=num_artifact_workers)
    os.makedirs(output_dir, exist_ok=True)
    save_future = save_thread.submit(save_results, output_queue, output_dir)

    try:
        for batch in dataloader:
            messages = create_messages(batch, prompt=prompt)
            outputs = vllm_engine.chat(messages, sampling_params)
            output_queue.put((batch["video_names"], outputs))

    finally:
        # Send the sentinel so the writer exits, then wait for it to drain the queue.
        output_queue.put(None)
        save_thread.shutdown(wait=True)

    save_future.result()
    print("All processes completed. Caption generation and saving done.")


if __name__ == "__main__":
    fire.Fire(main)
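
As a quick sanity check of the chat message format used above, a minimal single-video smoke test. This is a sketch, not part of the PR: it assumes dataset.py sits next to the script and that a local videos/ directory with .mp4 files exists.

from vllm import LLM, SamplingParams

from dataset import VideoDataset

# Hypothetical local folder with a few .mp4 files; adjust the path as needed.
dataset = VideoDataset("videos", max_num_frames=4)
item = dataset[0]

llm = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": 4})
content = [{"type": "text", "text": "Describe this set of frames."}]
content += [
    {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}}
    for frame in item["video"]
]
outputs = llm.chat([{"role": "user", "content": content}], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)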