From 247a6bef76c0341e30524508be13f3b85f6b5e31 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Oct 2024 12:44:22 +0530 Subject: [PATCH 01/14] add dataset --- Makefile | 2 +- video_recaptioning/dataset.py | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 1 deletion(-) create mode 100644 video_recaptioning/dataset.py diff --git a/Makefile b/Makefile index 78644f9d..b3936355 100644 --- a/Makefile +++ b/Makefile @@ -4,7 +4,7 @@ check_dirs := training tests quality: ruff check $(check_dirs) - ruff format --check $(check_dirs) setup.py + ruff format --check $(check_dirs) style: ruff check $(check_dirs) --fix diff --git a/video_recaptioning/dataset.py b/video_recaptioning/dataset.py new file mode 100644 index 00000000..3527a7b5 --- /dev/null +++ b/video_recaptioning/dataset.py @@ -0,0 +1,47 @@ +import decord + +from torch.utils.data import Dataset +from PIL import Image +from typing import Tuple, List +import os + +class VideoDataset(Dataset): + def __init__(self, root_video_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4")): + self.root_video_dir = root_video_dir + self.max_num_frames = max_num_frames + + video_files = [ + os.path.join(root_video_dir, f) for f in os.listdir(root_video_dir) if f.endswith(video_extensions) + ] + self.video_files = sorted(video_files) + + def __len__(self) -> int: + return len(self.video_files) + + def __getitem__(self, index: int) -> List[Image.Image]: + video_path = self.video_files[index] + return self.load_video(video_path) + + def load_video(self, path: str) -> List[Image.Image]: + video_reader = decord.VideoReader(uri=path) + + video_frames = [ + Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader)) + ][:self.max_num_frames] + return video_frames + +if __name__ == "__main__": + from huggingface_hub import snapshot_download + import tempfile + + with tempfile.TemporaryDirectory() as tmpdirname: + video_root_dir = snapshot_download( + repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset", local_dir=tmpdirname + ) + + dataset = VideoDataset(os.path.join(video_root_dir, "videos"), max_num_frames=16) + print(len(dataset)) + + for item in dataset: + print(len(item)) + break \ No newline at end of file From e12cc03735df365501374ae97f9d9d78d5834249 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Oct 2024 15:15:29 +0530 Subject: [PATCH 02/14] add captioner. --- video_recaptioning/dataset.py | 13 +++- video_recaptioning/launch.sh | 10 +++ video_recaptioning/recaption.py | 110 ++++++++++++++++++++++++++++++++ 3 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 video_recaptioning/launch.sh create mode 100644 video_recaptioning/recaption.py diff --git a/video_recaptioning/dataset.py b/video_recaptioning/dataset.py index 3527a7b5..be5eff96 100644 --- a/video_recaptioning/dataset.py +++ b/video_recaptioning/dataset.py @@ -4,6 +4,8 @@ from PIL import Image from typing import Tuple, List import os +import io +import base64 class VideoDataset(Dataset): def __init__(self, root_video_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4")): @@ -24,11 +26,18 @@ def __getitem__(self, index: int) -> List[Image.Image]: def load_video(self, path: str) -> List[Image.Image]: video_reader = decord.VideoReader(uri=path) + base_name = os.path.basename(path).split(".")[0] video_frames = [ Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader)) ][:self.max_num_frames] - return video_frames + return {"video": [self.encode_image(frame) for frame in video_frames], "video_name": base_name} + + def encode_image(self, image): + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + image_bytes = buffered.getvalue() + return base64.b64encode(image_bytes).decode("utf-8") if __name__ == "__main__": from huggingface_hub import snapshot_download @@ -43,5 +52,5 @@ def load_video(self, path: str) -> List[Image.Image]: print(len(dataset)) for item in dataset: - print(len(item)) + print(len(item["video"])) break \ No newline at end of file diff --git a/video_recaptioning/launch.sh b/video_recaptioning/launch.sh new file mode 100644 index 00000000..c93563ec --- /dev/null +++ b/video_recaptioning/launch.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=1,2 + +python recaption.py --root_dir="video-dataset-disney/videos" \ + --output_dir="video-dataset-disney" \ + --max_num_frames=8 --max_tokens=120 \ + --num_data_workers=4 --batch_size=2 \ + --prompt="Describe this set of frames. Consider the frames to be a part of the same video." \ + --num_artifact_workers=4 \ No newline at end of file diff --git a/video_recaptioning/recaption.py b/video_recaptioning/recaption.py new file mode 100644 index 00000000..0bb5b0d1 --- /dev/null +++ b/video_recaptioning/recaption.py @@ -0,0 +1,110 @@ +""" +Needs `vllm` to be installed from the `main`. +""" + +from vllm import LLM, SamplingParams +import queue +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from concurrent.futures import ThreadPoolExecutor +import torch +import fire +import os + +from dataset import VideoDataset + +def save_results(output_queue, output_dir): + while True: + try: + item = output_queue.get(timeout=5) + if item is None: + break + + video_names, outputs = item + outputs = [o.outputs[0].text for o in outputs] + + for i, pred_caption in enumerate(outputs): + with open(os.path.join(output_dir, f"{video_names[i]}_caption.txt"), "w") as f: + f.write(pred_caption) + + except queue.Empty: + continue + +def create_messages(batch, prompt: str): + messages = [] + for i, video in enumerate(batch["videos"]): + messages.append({"role": "user", "content": []}) + messages[i]["content"].append({"type": "text", "text": prompt}) + for j in range(len(video)): + new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{video[j]}"}} + messages[i]["content"].append(new_image) + return messages + +def collate_fn(batch): + inputs = { + "videos": [sample["video"] for sample in batch], + "video_names": [sample["video_name"] for sample in batch] + } + return inputs + +def prepare_dataloader(video_root_dir, max_num_frames, num_data_workers, batch_size): + dataset = VideoDataset(video_root_dir, max_num_frames=max_num_frames) + + rank = 0 + world_size = 1 + if torch.distributed.is_initialized(): + group = torch.distributed.group.WORLD + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False) + + # Create DataLoader + dataloader = DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + collate_fn=collate_fn, + num_workers=num_data_workers, + pin_memory=True + ) + return dataloader + +def load_model(max_num_frames: int, max_tokens: int): + vllm_engine = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": max_num_frames}) + sampling_params = SamplingParams(max_tokens=max_tokens) + return vllm_engine, sampling_params + + +def main( + root_dir, prompt, output_dir, max_num_frames, max_tokens, num_data_workers=4, batch_size=8, num_artifact_workers=4 + ): + max_allowed_imgs_per_req = batch_size * max_num_frames + vllm_engine, sampling_params = load_model( + max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens + ) + dataloader = prepare_dataloader( + video_root_dir=root_dir, max_num_frames=max_num_frames, + num_data_workers=num_data_workers, batch_size=batch_size + ) + + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=num_artifact_workers) + os.makedirs(output_dir, exist_ok=True) + save_future = save_thread.submit(save_results, output_queue, output_dir) + + try: + for batch in dataloader: + messages = create_messages(batch, prompt=prompt) + outputs = vllm_engine.chat(messages, sampling_params) + output_queue.put((batch["video_names"], outputs)) + + finally: + output_queue.put(None) + save_thread.shutdown(wait=True) + + save_future.result() + print("All processes completed. Caption generation and saving done.") + + +if __name__ == "__main__": + fire.Fire(main) \ No newline at end of file From e16ea4cd377d63b2828f03fd506fe37e97e8857d Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 15 Oct 2024 15:17:57 +0530 Subject: [PATCH 03/14] makefile --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index b3936355..eefa9cbd 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: quality style -check_dirs := training tests +check_dirs := training tests video_recaptioning quality: ruff check $(check_dirs) From 132e5063a78b9e8fa8319b30e0275eae77febd61 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 17 Oct 2024 15:51:55 +0530 Subject: [PATCH 04/14] fixes --- video_recaptioning/dataset.py | 26 +++++++++++-- video_recaptioning/launch.sh | 3 +- video_recaptioning/recaption.py | 69 +++++++++++++++++++-------------- 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/video_recaptioning/dataset.py b/video_recaptioning/dataset.py index be5eff96..f0b8af37 100644 --- a/video_recaptioning/dataset.py +++ b/video_recaptioning/dataset.py @@ -6,16 +6,34 @@ import os import io import base64 +import sys class VideoDataset(Dataset): - def __init__(self, root_video_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4")): + def __init__( + self, root_video_dir: str, output_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4") + ): self.root_video_dir = root_video_dir self.max_num_frames = max_num_frames - video_files = [ + # Filter out existing captions. + video_files = { os.path.join(root_video_dir, f) for f in os.listdir(root_video_dir) if f.endswith(video_extensions) - ] - self.video_files = sorted(video_files) + } + existing_caption_basenames = {os.path.splitext(f)[0] for f in os.listdir(output_dir) if "_caption.txt" in f} + if existing_caption_basenames: + if len(existing_caption_basenames) == len(video_files): + sys.exit("It seems like all the input videos have been already captioned. So, we're exiting the program.") + filtered_video_files = [ + f for f in video_files if os.path.splitext(os.path.basename(f))[0] + "_caption" not in existing_caption_basenames + ] + if len(video_files) > len(filtered_video_files): + diff = len(video_files) - len(filtered_video_files) + print(f"Found existing captions for {diff} videos. Will skip them.") + + self.video_files = sorted(filtered_video_files) + else: + self.video_files = sorted(video_files) + print(f"Total videos found: {len(self.video_files)}.") def __len__(self) -> int: return len(self.video_files) diff --git a/video_recaptioning/launch.sh b/video_recaptioning/launch.sh index c93563ec..2094e0bb 100644 --- a/video_recaptioning/launch.sh +++ b/video_recaptioning/launch.sh @@ -1,9 +1,10 @@ #!/bin/bash -export CUDA_VISIBLE_DEVICES=1,2 +export CUDA_VISIBLE_DEVICES=0,1 python recaption.py --root_dir="video-dataset-disney/videos" \ --output_dir="video-dataset-disney" \ + --num_devices=2 \ --max_num_frames=8 --max_tokens=120 \ --num_data_workers=4 --batch_size=2 \ --prompt="Describe this set of frames. Consider the frames to be a part of the same video." \ diff --git a/video_recaptioning/recaption.py b/video_recaptioning/recaption.py index 0bb5b0d1..346a99ce 100644 --- a/video_recaptioning/recaption.py +++ b/video_recaptioning/recaption.py @@ -5,9 +5,7 @@ from vllm import LLM, SamplingParams import queue from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler from concurrent.futures import ThreadPoolExecutor -import torch import fire import os @@ -30,15 +28,17 @@ def save_results(output_queue, output_dir): except queue.Empty: continue -def create_messages(batch, prompt: str): - messages = [] +def create_conversations(batch, prompt: str): + conversations = [] for i, video in enumerate(batch["videos"]): - messages.append({"role": "user", "content": []}) - messages[i]["content"].append({"type": "text", "text": prompt}) + content = [] + content.append({"type": "text", "text": "Describe this set of frames. Consider the frames to be a part of the same video."}) for j in range(len(video)): new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{video[j]}"}} - messages[i]["content"].append(new_image) - return messages + content.append(new_image) + message = {"role": "user", "content": content} + conversations.append([message]) + return conversations def collate_fn(batch): inputs = { @@ -47,44 +47,55 @@ def collate_fn(batch): } return inputs -def prepare_dataloader(video_root_dir, max_num_frames, num_data_workers, batch_size): - dataset = VideoDataset(video_root_dir, max_num_frames=max_num_frames) - - rank = 0 - world_size = 1 - if torch.distributed.is_initialized(): - group = torch.distributed.group.WORLD - rank = torch.distributed.get_rank(group=group) - world_size = torch.distributed.get_world_size(group=group) - sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False) +def prepare_dataloader( + video_root_dir, output_dir, video_extensions, max_num_frames, num_data_workers, batch_size + ): + dataset = VideoDataset( + video_root_dir, output_dir=output_dir, max_num_frames=max_num_frames, video_extensions=video_extensions + ) # Create DataLoader dataloader = DataLoader( dataset, batch_size=batch_size, - sampler=sampler, collate_fn=collate_fn, num_workers=num_data_workers, - pin_memory=True + pin_memory=True, + persistent_workers=True ) return dataloader -def load_model(max_num_frames: int, max_tokens: int): - vllm_engine = LLM("Qwen/Qwen2-VL-2B-Instruct", limit_mm_per_prompt={"image": max_num_frames}) +def load_model(max_num_frames: int, max_tokens: int, num_devices: int): + vllm_engine = LLM( + "Qwen/Qwen2-VL-2B-Instruct", tensor_parallel_size=num_devices, limit_mm_per_prompt={"image": max_num_frames} + ) sampling_params = SamplingParams(max_tokens=max_tokens) return vllm_engine, sampling_params def main( - root_dir, prompt, output_dir, max_num_frames, max_tokens, num_data_workers=4, batch_size=8, num_artifact_workers=4 + root_dir: str, + prompt: str, + output_dir: str, + num_devices: int, + max_num_frames: int, + max_tokens: int, + video_extensions: tuple = (".mp4"), + num_data_workers:int = 4, + batch_size:int = 8, + num_artifact_workers:int = 4 ): max_allowed_imgs_per_req = batch_size * max_num_frames vllm_engine, sampling_params = load_model( - max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens + max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens, num_devices=num_devices ) dataloader = prepare_dataloader( - video_root_dir=root_dir, max_num_frames=max_num_frames, - num_data_workers=num_data_workers, batch_size=batch_size + video_root_dir=root_dir, + output_dir=output_dir, + video_extensions=video_extensions, + max_num_frames=max_num_frames, + num_data_workers=num_data_workers, + batch_size=batch_size ) output_queue = queue.Queue() @@ -93,9 +104,9 @@ def main( save_future = save_thread.submit(save_results, output_queue, output_dir) try: - for batch in dataloader: - messages = create_messages(batch, prompt=prompt) - outputs = vllm_engine.chat(messages, sampling_params) + for idx, batch in enumerate(dataloader): + conversations = create_conversations(batch, prompt=prompt) + outputs = vllm_engine.chat(conversations, sampling_params) output_queue.put((batch["video_names"], outputs)) finally: From 97a0af8abcee0572f4515570ad4ec89586a84670 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 17 Oct 2024 13:48:41 +0200 Subject: [PATCH 05/14] add download_dir param --- video_recaptioning/dataset.py | 5 ++++- video_recaptioning/launch.sh | 36 ++++++++++++++++++++++++++------- video_recaptioning/recaption.py | 12 ++++++----- 3 files changed, 40 insertions(+), 13 deletions(-) mode change 100644 => 100755 video_recaptioning/launch.sh diff --git a/video_recaptioning/dataset.py b/video_recaptioning/dataset.py index f0b8af37..4322e9c6 100644 --- a/video_recaptioning/dataset.py +++ b/video_recaptioning/dataset.py @@ -19,7 +19,10 @@ def __init__( video_files = { os.path.join(root_video_dir, f) for f in os.listdir(root_video_dir) if f.endswith(video_extensions) } - existing_caption_basenames = {os.path.splitext(f)[0] for f in os.listdir(output_dir) if "_caption.txt" in f} + if os.path.isdir(output_dir): + existing_caption_basenames = {os.path.splitext(f)[0] for f in os.listdir(output_dir) if "_caption.txt" in f} + else: + existing_caption_basenames = None if existing_caption_basenames: if len(existing_caption_basenames) == len(video_files): sys.exit("It seems like all the input videos have been already captioned. So, we're exiting the program.") diff --git a/video_recaptioning/launch.sh b/video_recaptioning/launch.sh old mode 100644 new mode 100755 index 2094e0bb..aafb7b2e --- a/video_recaptioning/launch.sh +++ b/video_recaptioning/launch.sh @@ -2,10 +2,32 @@ export CUDA_VISIBLE_DEVICES=0,1 -python recaption.py --root_dir="video-dataset-disney/videos" \ - --output_dir="video-dataset-disney" \ - --num_devices=2 \ - --max_num_frames=8 --max_tokens=120 \ - --num_data_workers=4 --batch_size=2 \ - --prompt="Describe this set of frames. Consider the frames to be a part of the same video." \ - --num_artifact_workers=4 \ No newline at end of file +# Path to where vLLM models should be downloaded +DOWNLOAD_DIR="/path/to/download/dir" + +# Path to where video files are located +ROOT_DIR="/path/to/video/files" + +# Path to where captions should be stored +OUTPUT_DIR="/path/to/save/captions" + +# Other configurations +MAX_FRAMES=8 +MAX_TOKENS=120 +BATCH_SIZE=2 +NUM_DATA_WORKERS=4 +NUM_ARTIFACT_WORKERS=4 + +PROMPT="Please describe the content of this video in as much detail as possible, including the objects, scenery, animals, characters, and camera movements within the video. Do not include '\n' in your response. Please start the description with the video content directly. Please describe the content of the video and the changes that occur, in chronological order." + +python recaption.py \ + --root_dir $ROOT_DIR \ + --output_dir $OUTPUT_DIR \ + --num_devices 1 \ + --max_num_frames $MAX_FRAMES \ + --max_tokens $MAX_TOKENS \ + --num_data_workers $NUM_DATA_WORKERS \ + --batch_size $BATCH_SIZE \ + --prompt $PROMPT \ + --num_artifact_workers $NUM_ARTIFACT_WORKERS \ + --download_dir $DOWNLOAD_DIR diff --git a/video_recaptioning/recaption.py b/video_recaptioning/recaption.py index 346a99ce..8923eb9f 100644 --- a/video_recaptioning/recaption.py +++ b/video_recaptioning/recaption.py @@ -2,6 +2,7 @@ Needs `vllm` to be installed from the `main`. """ +from typing import Optional from vllm import LLM, SamplingParams import queue from torch.utils.data import DataLoader @@ -65,9 +66,9 @@ def prepare_dataloader( ) return dataloader -def load_model(max_num_frames: int, max_tokens: int, num_devices: int): +def load_model(max_num_frames: int, max_tokens: int, num_devices: int, download_dir: Optional[str] = None): vllm_engine = LLM( - "Qwen/Qwen2-VL-2B-Instruct", tensor_parallel_size=num_devices, limit_mm_per_prompt={"image": max_num_frames} + "Qwen/Qwen2-VL-2B-Instruct", tensor_parallel_size=num_devices, limit_mm_per_prompt={"image": max_num_frames}, download_dir=download_dir ) sampling_params = SamplingParams(max_tokens=max_tokens) return vllm_engine, sampling_params @@ -83,11 +84,12 @@ def main( video_extensions: tuple = (".mp4"), num_data_workers:int = 4, batch_size:int = 8, - num_artifact_workers:int = 4 + num_artifact_workers:int = 4, + download_dir: Optional[str] = None, ): max_allowed_imgs_per_req = batch_size * max_num_frames vllm_engine, sampling_params = load_model( - max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens, num_devices=num_devices + max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens, num_devices=num_devices, download_dir=download_dir, ) dataloader = prepare_dataloader( video_root_dir=root_dir, @@ -118,4 +120,4 @@ def main( if __name__ == "__main__": - fire.Fire(main) \ No newline at end of file + fire.Fire(main) From 61d269a9c414cf6a36cd07ef5b881cbee3fc1bf8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 17 Oct 2024 13:49:26 +0200 Subject: [PATCH 06/14] make style --- video_recaptioning/dataset.py | 51 ++++++++++++--------- video_recaptioning/recaption.py | 78 +++++++++++++++++++-------------- 2 files changed, 76 insertions(+), 53 deletions(-) diff --git a/video_recaptioning/dataset.py b/video_recaptioning/dataset.py index 4322e9c6..ff226ed3 100644 --- a/video_recaptioning/dataset.py +++ b/video_recaptioning/dataset.py @@ -1,17 +1,18 @@ -import decord - -from torch.utils.data import Dataset -from PIL import Image -from typing import Tuple, List -import os -import io import base64 +import io +import os import sys +from typing import List, Tuple + +import decord +from PIL import Image +from torch.utils.data import Dataset + class VideoDataset(Dataset): def __init__( - self, root_video_dir: str, output_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4") - ): + self, root_video_dir: str, output_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4") + ): self.root_video_dir = root_video_dir self.max_num_frames = max_num_frames @@ -20,19 +21,25 @@ def __init__( os.path.join(root_video_dir, f) for f in os.listdir(root_video_dir) if f.endswith(video_extensions) } if os.path.isdir(output_dir): - existing_caption_basenames = {os.path.splitext(f)[0] for f in os.listdir(output_dir) if "_caption.txt" in f} + existing_caption_basenames = { + os.path.splitext(f)[0] for f in os.listdir(output_dir) if "_caption.txt" in f + } else: existing_caption_basenames = None if existing_caption_basenames: if len(existing_caption_basenames) == len(video_files): - sys.exit("It seems like all the input videos have been already captioned. So, we're exiting the program.") + sys.exit( + "It seems like all the input videos have been already captioned. So, we're exiting the program." + ) filtered_video_files = [ - f for f in video_files if os.path.splitext(os.path.basename(f))[0] + "_caption" not in existing_caption_basenames + f + for f in video_files + if os.path.splitext(os.path.basename(f))[0] + "_caption" not in existing_caption_basenames ] if len(video_files) > len(filtered_video_files): diff = len(video_files) - len(filtered_video_files) - print(f"Found existing captions for {diff} videos. Will skip them.") - + print(f"Found existing captions for {diff} videos. Will skip them.") + self.video_files = sorted(filtered_video_files) else: self.video_files = sorted(video_files) @@ -44,14 +51,14 @@ def __len__(self) -> int: def __getitem__(self, index: int) -> List[Image.Image]: video_path = self.video_files[index] return self.load_video(video_path) - + def load_video(self, path: str) -> List[Image.Image]: video_reader = decord.VideoReader(uri=path) base_name = os.path.basename(path).split(".")[0] - video_frames = [ - Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader)) - ][:self.max_num_frames] + video_frames = [Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader))][ + : self.max_num_frames + ] return {"video": [self.encode_image(frame) for frame in video_frames], "video_name": base_name} def encode_image(self, image): @@ -60,18 +67,20 @@ def encode_image(self, image): image_bytes = buffered.getvalue() return base64.b64encode(image_bytes).decode("utf-8") + if __name__ == "__main__": - from huggingface_hub import snapshot_download import tempfile + from huggingface_hub import snapshot_download + with tempfile.TemporaryDirectory() as tmpdirname: video_root_dir = snapshot_download( repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset", local_dir=tmpdirname ) - + dataset = VideoDataset(os.path.join(video_root_dir, "videos"), max_num_frames=16) print(len(dataset)) for item in dataset: print(len(item["video"])) - break \ No newline at end of file + break diff --git a/video_recaptioning/recaption.py b/video_recaptioning/recaption.py index 8923eb9f..64b58617 100644 --- a/video_recaptioning/recaption.py +++ b/video_recaptioning/recaption.py @@ -2,15 +2,16 @@ Needs `vllm` to be installed from the `main`. """ -from typing import Optional -from vllm import LLM, SamplingParams +import os import queue -from torch.utils.data import DataLoader from concurrent.futures import ThreadPoolExecutor -import fire -import os +from typing import Optional +import fire from dataset import VideoDataset +from torch.utils.data import DataLoader +from vllm import LLM, SamplingParams + def save_results(output_queue, output_dir): while True: @@ -29,28 +30,34 @@ def save_results(output_queue, output_dir): except queue.Empty: continue + def create_conversations(batch, prompt: str): conversations = [] for i, video in enumerate(batch["videos"]): content = [] - content.append({"type": "text", "text": "Describe this set of frames. Consider the frames to be a part of the same video."}) + content.append( + { + "type": "text", + "text": "Describe this set of frames. Consider the frames to be a part of the same video.", + } + ) for j in range(len(video)): new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{video[j]}"}} content.append(new_image) message = {"role": "user", "content": content} - conversations.append([message]) + conversations.append([message]) return conversations + def collate_fn(batch): inputs = { "videos": [sample["video"] for sample in batch], - "video_names": [sample["video_name"] for sample in batch] + "video_names": [sample["video_name"] for sample in batch], } return inputs -def prepare_dataloader( - video_root_dir, output_dir, video_extensions, max_num_frames, num_data_workers, batch_size - ): + +def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_frames, num_data_workers, batch_size): dataset = VideoDataset( video_root_dir, output_dir=output_dir, max_num_frames=max_num_frames, video_extensions=video_extensions ) @@ -58,46 +65,53 @@ def prepare_dataloader( # Create DataLoader dataloader = DataLoader( dataset, - batch_size=batch_size, + batch_size=batch_size, collate_fn=collate_fn, num_workers=num_data_workers, pin_memory=True, - persistent_workers=True + persistent_workers=True, ) return dataloader + def load_model(max_num_frames: int, max_tokens: int, num_devices: int, download_dir: Optional[str] = None): vllm_engine = LLM( - "Qwen/Qwen2-VL-2B-Instruct", tensor_parallel_size=num_devices, limit_mm_per_prompt={"image": max_num_frames}, download_dir=download_dir + "Qwen/Qwen2-VL-2B-Instruct", + tensor_parallel_size=num_devices, + limit_mm_per_prompt={"image": max_num_frames}, + download_dir=download_dir, ) sampling_params = SamplingParams(max_tokens=max_tokens) return vllm_engine, sampling_params def main( - root_dir: str, - prompt: str, - output_dir: str, - num_devices: int, - max_num_frames: int, - max_tokens: int, - video_extensions: tuple = (".mp4"), - num_data_workers:int = 4, - batch_size:int = 8, - num_artifact_workers:int = 4, - download_dir: Optional[str] = None, - ): + root_dir: str, + prompt: str, + output_dir: str, + num_devices: int, + max_num_frames: int, + max_tokens: int, + video_extensions: tuple = (".mp4"), + num_data_workers: int = 4, + batch_size: int = 8, + num_artifact_workers: int = 4, + download_dir: Optional[str] = None, +): max_allowed_imgs_per_req = batch_size * max_num_frames vllm_engine, sampling_params = load_model( - max_num_frames=max_allowed_imgs_per_req, max_tokens=max_tokens, num_devices=num_devices, download_dir=download_dir, + max_num_frames=max_allowed_imgs_per_req, + max_tokens=max_tokens, + num_devices=num_devices, + download_dir=download_dir, ) dataloader = prepare_dataloader( - video_root_dir=root_dir, + video_root_dir=root_dir, output_dir=output_dir, video_extensions=video_extensions, - max_num_frames=max_num_frames, - num_data_workers=num_data_workers, - batch_size=batch_size + max_num_frames=max_num_frames, + num_data_workers=num_data_workers, + batch_size=batch_size, ) output_queue = queue.Queue() @@ -117,7 +131,7 @@ def main( save_future.result() print("All processes completed. Caption generation and saving done.") - + if __name__ == "__main__": fire.Fire(main) From 218f48c992d5c9ce2775e0a9a5782ae749147504 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 11 Nov 2024 23:36:32 +0100 Subject: [PATCH 07/14] update --- {video_recaptioning => captioning}/dataset.py | 4 +- {video_recaptioning => captioning}/launch.sh | 0 .../recaption.py | 83 ++++++++++++++++--- tests/test_lora_inference.py | 1 + training/cogvideox_text_to_video_lora.py | 2 +- training/cogvideox_text_to_video_sft.py | 2 +- 6 files changed, 76 insertions(+), 16 deletions(-) rename {video_recaptioning => captioning}/dataset.py (99%) rename {video_recaptioning => captioning}/launch.sh (100%) rename {video_recaptioning => captioning}/recaption.py (50%) diff --git a/video_recaptioning/dataset.py b/captioning/dataset.py similarity index 99% rename from video_recaptioning/dataset.py rename to captioning/dataset.py index ff226ed3..26ef56e4 100644 --- a/video_recaptioning/dataset.py +++ b/captioning/dataset.py @@ -4,11 +4,13 @@ import sys from typing import List, Tuple -import decord from PIL import Image from torch.utils.data import Dataset +import decord # isort:skip + + class VideoDataset(Dataset): def __init__( self, root_video_dir: str, output_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4") diff --git a/video_recaptioning/launch.sh b/captioning/launch.sh similarity index 100% rename from video_recaptioning/launch.sh rename to captioning/launch.sh diff --git a/video_recaptioning/recaption.py b/captioning/recaption.py similarity index 50% rename from video_recaptioning/recaption.py rename to captioning/recaption.py index 64b58617..6d2ea9b5 100644 --- a/video_recaptioning/recaption.py +++ b/captioning/recaption.py @@ -13,6 +13,23 @@ from vllm import LLM, SamplingParams +SYSTEM_PROMPT = r""" +You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. +For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model. +There are a few rules to follow: +- You will only ever output a single video description per request. +- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit. +You responses should just be the video generation prompt. Here are examples: +- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." +- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" +""".strip() + + +USER_PROMPT = r""" +Summarize this set of frames. Consider these frames to be part of the same video. Please limit the summary to 100 words. Describe the content generally and do not start with phrases like "This video is about" or "This video captures". +""" + + def save_results(output_queue, output_dir): while True: try: @@ -20,32 +37,62 @@ def save_results(output_queue, output_dir): if item is None: break - video_names, outputs = item + video_filenames, outputs = item outputs = [o.outputs[0].text for o in outputs] - for i, pred_caption in enumerate(outputs): - with open(os.path.join(output_dir, f"{video_names[i]}_caption.txt"), "w") as f: - f.write(pred_caption) + with open(os.path.join(output_dir, "videos.txt"), "a") as file: + for filename in video_filenames: + file.write(filename + "\n") + + with open(os.path.join(output_dir, "captions.txt"), "a") as file: + for caption in outputs: + file.write(caption + "\n") except queue.Empty: continue -def create_conversations(batch, prompt: str): +def create_conversations(batch, prompt: Optional[str] = None): + if prompt is None: + prompt = USER_PROMPT + conversations = [] + for i, video in enumerate(batch["videos"]): + conversation = [] content = [] + content.append( { "type": "text", - "text": "Describe this set of frames. Consider the frames to be a part of the same video.", + "text": prompt, } ) - for j in range(len(video)): - new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{video[j]}"}} + for frame in video: + new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}} content.append(new_image) - message = {"role": "user", "content": content} - conversations.append([message]) + + conversation.append( + { + "role": "system", + "content": [ + { + "type": "text", + "text": SYSTEM_PROMPT, + } + ], + } + ) + + conversation.append( + { + "role": "user", + "content": content, + } + ) + + conversations.append(conversation) + return conversations @@ -74,12 +121,20 @@ def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_fra return dataloader -def load_model(max_num_frames: int, max_tokens: int, num_devices: int, download_dir: Optional[str] = None): +def load_model( + max_num_frames: int, + max_tokens: int, + num_devices: int, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): vllm_engine = LLM( - "Qwen/Qwen2-VL-2B-Instruct", + "openbmb/MiniCPM-V-2_6", + # "Qwen/Qwen2-VL-2B-Instruct", tensor_parallel_size=num_devices, limit_mm_per_prompt={"image": max_num_frames}, download_dir=download_dir, + trust_remote_code=trust_remote_code, ) sampling_params = SamplingParams(max_tokens=max_tokens) return vllm_engine, sampling_params @@ -87,16 +142,17 @@ def load_model(max_num_frames: int, max_tokens: int, num_devices: int, download_ def main( root_dir: str, - prompt: str, output_dir: str, num_devices: int, max_num_frames: int, max_tokens: int, + prompt: Optional[str] = None, video_extensions: tuple = (".mp4"), num_data_workers: int = 4, batch_size: int = 8, num_artifact_workers: int = 4, download_dir: Optional[str] = None, + trust_remote_code: bool = False, ): max_allowed_imgs_per_req = batch_size * max_num_frames vllm_engine, sampling_params = load_model( @@ -104,6 +160,7 @@ def main( max_tokens=max_tokens, num_devices=num_devices, download_dir=download_dir, + trust_remote_code=trust_remote_code, ) dataloader = prepare_dataloader( video_root_dir=root_dir, diff --git a/tests/test_lora_inference.py b/tests/test_lora_inference.py index 31428a4b..d1c0439b 100644 --- a/tests/test_lora_inference.py +++ b/tests/test_lora_inference.py @@ -8,6 +8,7 @@ """ import argparse + import torch from diffusers import CogVideoXPipeline from diffusers.utils import export_to_video diff --git a/training/cogvideox_text_to_video_lora.py b/training/cogvideox_text_to_video_lora.py index cf678a5a..2e071440 100644 --- a/training/cogvideox_text_to_video_lora.py +++ b/training/cogvideox_text_to_video_lora.py @@ -62,7 +62,7 @@ print_memory, reset_memory, unwrap_model, -) # isort:skip +) logger = get_logger(__name__) diff --git a/training/cogvideox_text_to_video_sft.py b/training/cogvideox_text_to_video_sft.py index 9b442175..01563019 100644 --- a/training/cogvideox_text_to_video_sft.py +++ b/training/cogvideox_text_to_video_sft.py @@ -61,7 +61,7 @@ print_memory, reset_memory, unwrap_model, -) # isort:skip +) logger = get_logger(__name__) From 415b07e69360fd8eee37ab847a3aeb1a31e9a378 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 12 Nov 2024 14:35:27 +0100 Subject: [PATCH 08/14] update --- Makefile | 2 +- captioning/dataset.py | 5 +- captioning/{recaption.py => vllm_caption.py} | 148 +++++++++++++------ 3 files changed, 108 insertions(+), 47 deletions(-) rename captioning/{recaption.py => vllm_caption.py} (59%) diff --git a/Makefile b/Makefile index 9bc09257..d0c7cb1e 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: quality style -check_dirs := training tests video_recaptioning +check_dirs := training tests captioning quality: ruff check $(check_dirs) diff --git a/captioning/dataset.py b/captioning/dataset.py index 26ef56e4..ca379e89 100644 --- a/captioning/dataset.py +++ b/captioning/dataset.py @@ -56,12 +56,13 @@ def __getitem__(self, index: int) -> List[Image.Image]: def load_video(self, path: str) -> List[Image.Image]: video_reader = decord.VideoReader(uri=path) - base_name = os.path.basename(path).split(".")[0] + filename = os.path.basename(path) video_frames = [Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader))][ : self.max_num_frames ] - return {"video": [self.encode_image(frame) for frame in video_frames], "video_name": base_name} + + return {"video": [self.encode_image(frame) for frame in video_frames], "filename": filename} def encode_image(self, image): buffered = io.BytesIO() diff --git a/captioning/recaption.py b/captioning/vllm_caption.py similarity index 59% rename from captioning/recaption.py rename to captioning/vllm_caption.py index 6d2ea9b5..4c975506 100644 --- a/captioning/recaption.py +++ b/captioning/vllm_caption.py @@ -2,16 +2,19 @@ Needs `vllm` to be installed from the `main`. """ +import gc import os import queue from concurrent.futures import ThreadPoolExecutor from typing import Optional import fire -from dataset import VideoDataset +import torch from torch.utils.data import DataLoader from vllm import LLM, SamplingParams +from dataset import VideoDataset # isort:skip + SYSTEM_PROMPT = r""" You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. @@ -24,10 +27,17 @@ - "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" """.strip() +SUMMARY_USER_PROMPT = r"""Please summarize this video and limit the summary to 100-200 words.""".strip() -USER_PROMPT = r""" -Summarize this set of frames. Consider these frames to be part of the same video. Please limit the summary to 100 words. Describe the content generally and do not start with phrases like "This video is about" or "This video captures". -""" +PROMPT_GEN_USER_PROMPT = r""" +Could you generate a prompt for a video generation model given the following summary: + +``` +{0} +``` + +Please limit the prompt to [{1}] words. +""".strip() def save_results(output_queue, output_dir): @@ -38,7 +48,6 @@ def save_results(output_queue, output_dir): break video_filenames, outputs = item - outputs = [o.outputs[0].text for o in outputs] with open(os.path.join(output_dir, "videos.txt"), "a") as file: for filename in video_filenames: @@ -52,9 +61,9 @@ def save_results(output_queue, output_dir): continue -def create_conversations(batch, prompt: Optional[str] = None): +def create_video_summary_conversations(batch, prompt: Optional[str] = None): if prompt is None: - prompt = USER_PROMPT + prompt = SUMMARY_USER_PROMPT conversations = [] @@ -62,34 +71,36 @@ def create_conversations(batch, prompt: Optional[str] = None): conversation = [] content = [] - content.append( - { - "type": "text", - "text": prompt, - } - ) + content.append({"type": "text", "text": prompt}) for frame in video: new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}} content.append(new_image) - conversation.append( - { - "role": "system", - "content": [ - { - "type": "text", - "text": SYSTEM_PROMPT, - } - ], - } - ) + # conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) + conversation.append({"role": "user", "content": content}) - conversation.append( - { - "role": "user", - "content": content, - } - ) + conversations.append(conversation) + + return conversations + + +def create_prompt_generation_conversations(batch, prompt: Optional[str] = None): + if prompt is None: + prompt = PROMPT_GEN_USER_PROMPT + + conversations = [] + + for i, summary in enumerate(batch["summary"]): + conversation = [] + content = [] + + content.append({ + "type": "text", + "text": prompt.format(summary, 20) + }) + + conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) + conversation.append({"role": "user", "content": content}) conversations.append(conversation) @@ -99,7 +110,7 @@ def create_conversations(batch, prompt: Optional[str] = None): def collate_fn(batch): inputs = { "videos": [sample["video"] for sample in batch], - "video_names": [sample["video_name"] for sample in batch], + "filename": [sample["filename"] for sample in batch], } return inputs @@ -121,32 +132,51 @@ def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_fra return dataloader -def load_model( +def load_summary_model( max_num_frames: int, max_tokens: int, num_devices: int, download_dir: Optional[str] = None, trust_remote_code: bool = False, ): - vllm_engine = LLM( + engine = LLM( "openbmb/MiniCPM-V-2_6", - # "Qwen/Qwen2-VL-2B-Instruct", + dtype="bfloat16", tensor_parallel_size=num_devices, limit_mm_per_prompt={"image": max_num_frames}, download_dir=download_dir, trust_remote_code=trust_remote_code, ) sampling_params = SamplingParams(max_tokens=max_tokens) - return vllm_engine, sampling_params + return engine, sampling_params + + +def load_prompt_gen_model( + max_tokens: int, + num_devices: int, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + engine = LLM( + "meta-llama/Meta-Llama-3.1-8B-Instruct", + dtype="bfloat16", + tensor_parallel_size=num_devices, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + sampling_params = SamplingParams(max_tokens=max_tokens) + return engine, sampling_params def main( root_dir: str, output_dir: str, - num_devices: int, - max_num_frames: int, - max_tokens: int, - prompt: Optional[str] = None, + num_devices: int = 1, + max_num_frames: int = 8, + max_summary_tokens: int = 512, + max_prompt_gen_tokens: int = 256, + video_summary_prompt: Optional[str] = None, + prompt_gen_prompt: Optional[str] = None, video_extensions: tuple = (".mp4"), num_data_workers: int = 4, batch_size: int = 8, @@ -155,13 +185,15 @@ def main( trust_remote_code: bool = False, ): max_allowed_imgs_per_req = batch_size * max_num_frames - vllm_engine, sampling_params = load_model( + + summary_engine, summary_sampling_params = load_summary_model( max_num_frames=max_allowed_imgs_per_req, - max_tokens=max_tokens, + max_tokens=max_summary_tokens, num_devices=num_devices, download_dir=download_dir, trust_remote_code=trust_remote_code, ) + dataloader = prepare_dataloader( video_root_dir=root_dir, output_dir=output_dir, @@ -177,10 +209,38 @@ def main( save_future = save_thread.submit(save_results, output_queue, output_dir) try: + video_data = [] + for idx, batch in enumerate(dataloader): - conversations = create_conversations(batch, prompt=prompt) - outputs = vllm_engine.chat(conversations, sampling_params) - output_queue.put((batch["video_names"], outputs)) + conversations = create_video_summary_conversations(batch, prompt=video_summary_prompt) + video_summaries = summary_engine.chat(conversations, summary_sampling_params) + + video_data_item = { + "filename": batch["filename"], + "summary": [summary.outputs[0].text for summary in video_summaries] + } + + video_data.append(video_data_item) + + del summary_engine, summary_sampling_params + gc.collect() + torch.cuda.empty_cache() + + prompt_gen_engine, prompt_gen_sampling_params = load_prompt_gen_model( + max_tokens=max_prompt_gen_tokens, + num_devices=num_devices, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + + for idx, batch in enumerate(video_data): + conversations = create_prompt_generation_conversations(batch, prompt=prompt_gen_prompt) + prompts = prompt_gen_engine.chat(conversations, prompt_gen_sampling_params) + + # Get outputs and remove surrounding quotes + prompts = [prompt.outputs[0].text[1 : -1] for prompt in prompts] + + output_queue.put((batch["filename"], prompts)) finally: output_queue.put(None) From a0542f091f79f2f1b6c916339ecfd6d4186e1cb3 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 13 Nov 2024 00:46:43 +0000 Subject: [PATCH 09/14] update --- captioning/dataset_caption.py | 25 +++ captioning/{dataset.py => dataset_video.py} | 15 +- captioning/vllm_caption.py | 123 +++---------- captioning/vllm_summary.py | 188 ++++++++++++++++++++ training/dataset.py | 4 +- 5 files changed, 249 insertions(+), 106 deletions(-) create mode 100644 captioning/dataset_caption.py rename captioning/{dataset.py => dataset_video.py} (86%) create mode 100644 captioning/vllm_summary.py diff --git a/captioning/dataset_caption.py b/captioning/dataset_caption.py new file mode 100644 index 00000000..9d642bfe --- /dev/null +++ b/captioning/dataset_caption.py @@ -0,0 +1,25 @@ +import pathlib + +import pandas as pd +from torch.utils.data import Dataset + + +class CaptionDataset(Dataset): + def __init__(self, input_file: str) -> None: + self.input_file = pathlib.Path(input_file) + + assert self.input_file.is_file() + + df = pd.read_csv(input_file) + self.filenames = df["filename"] + self.summaries = df["summary"] + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, index: int): + return { + "filename": self.filenames[index], + "summary": self.summaries[index], + } + diff --git a/captioning/dataset.py b/captioning/dataset_video.py similarity index 86% rename from captioning/dataset.py rename to captioning/dataset_video.py index ca379e89..499fb47b 100644 --- a/captioning/dataset.py +++ b/captioning/dataset_video.py @@ -58,9 +58,18 @@ def load_video(self, path: str) -> List[Image.Image]: video_reader = decord.VideoReader(uri=path) filename = os.path.basename(path) - video_frames = [Image.fromarray(video_reader[i].asnumpy()) for i in range(len(video_reader))][ - : self.max_num_frames - ] + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + sample_fps = round(video_reader.get_avg_fps() / 1) + frame_idx = [i for i in range(0, len(video_reader), sample_fps)] + + if len(frame_idx) > self.max_num_frames: + frame_idx = uniform_sample(frame_idx, self.max_num_frames) + + video_frames = [Image.fromarray(video_reader[i].asnumpy()) for i in frame_idx] return {"video": [self.encode_image(frame) for frame in video_frames], "filename": filename} diff --git a/captioning/vllm_caption.py b/captioning/vllm_caption.py index 4c975506..5308f44b 100644 --- a/captioning/vllm_caption.py +++ b/captioning/vllm_caption.py @@ -13,22 +13,20 @@ from torch.utils.data import DataLoader from vllm import LLM, SamplingParams -from dataset import VideoDataset # isort:skip +from dataset_caption import CaptionDataset # isort:skip SYSTEM_PROMPT = r""" You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model. There are a few rules to follow: -- You will only ever output a single video description per request. +- You will only ever output a single video description per request. Do not use newlines. - If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit. You responses should just be the video generation prompt. Here are examples: - "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." - "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" """.strip() -SUMMARY_USER_PROMPT = r"""Please summarize this video and limit the summary to 100-200 words.""".strip() - PROMPT_GEN_USER_PROMPT = r""" Could you generate a prompt for a video generation model given the following summary: @@ -47,56 +45,32 @@ def save_results(output_queue, output_dir): if item is None: break - video_filenames, outputs = item + video_filenames, captions = item - with open(os.path.join(output_dir, "videos.txt"), "a") as file: + with open(os.path.join(output_dir, "videos.txt"), "a", encoding="utf-8") as file: for filename in video_filenames: file.write(filename + "\n") - - with open(os.path.join(output_dir, "captions.txt"), "a") as file: - for caption in outputs: + + with open(os.path.join(output_dir, "prompts.txt"), "a", encoding="utf-8") as file: + for caption in captions: file.write(caption + "\n") - except queue.Empty: continue -def create_video_summary_conversations(batch, prompt: Optional[str] = None): - if prompt is None: - prompt = SUMMARY_USER_PROMPT - - conversations = [] - - for i, video in enumerate(batch["videos"]): - conversation = [] - content = [] - - content.append({"type": "text", "text": prompt}) - for frame in video: - new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}} - content.append(new_image) - - # conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) - conversation.append({"role": "user", "content": content}) - - conversations.append(conversation) - - return conversations - - def create_prompt_generation_conversations(batch, prompt: Optional[str] = None): if prompt is None: prompt = PROMPT_GEN_USER_PROMPT conversations = [] - for i, summary in enumerate(batch["summary"]): + for i, summary in enumerate(batch): conversation = [] content = [] content.append({ "type": "text", - "text": prompt.format(summary, 20) + "text": prompt.format(summary, 50) }) conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) @@ -109,16 +83,14 @@ def create_prompt_generation_conversations(batch, prompt: Optional[str] = None): def collate_fn(batch): inputs = { - "videos": [sample["video"] for sample in batch], + "summary": [sample["summary"] for sample in batch], "filename": [sample["filename"] for sample in batch], } return inputs -def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_frames, num_data_workers, batch_size): - dataset = VideoDataset( - video_root_dir, output_dir=output_dir, max_num_frames=max_num_frames, video_extensions=video_extensions - ) +def prepare_dataloader(input_file, num_data_workers, batch_size): + dataset = CaptionDataset(input_file) # Create DataLoader dataloader = DataLoader( @@ -132,25 +104,6 @@ def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_fra return dataloader -def load_summary_model( - max_num_frames: int, - max_tokens: int, - num_devices: int, - download_dir: Optional[str] = None, - trust_remote_code: bool = False, -): - engine = LLM( - "openbmb/MiniCPM-V-2_6", - dtype="bfloat16", - tensor_parallel_size=num_devices, - limit_mm_per_prompt={"image": max_num_frames}, - download_dir=download_dir, - trust_remote_code=trust_remote_code, - ) - sampling_params = SamplingParams(max_tokens=max_tokens) - return engine, sampling_params - - def load_prompt_gen_model( max_tokens: int, num_devices: int, @@ -169,36 +122,26 @@ def load_prompt_gen_model( def main( - root_dir: str, + input_file: str, output_dir: str, num_devices: int = 1, - max_num_frames: int = 8, - max_summary_tokens: int = 512, max_prompt_gen_tokens: int = 256, - video_summary_prompt: Optional[str] = None, prompt_gen_prompt: Optional[str] = None, - video_extensions: tuple = (".mp4"), - num_data_workers: int = 4, batch_size: int = 8, + num_data_workers: int = 4, num_artifact_workers: int = 4, download_dir: Optional[str] = None, trust_remote_code: bool = False, ): - max_allowed_imgs_per_req = batch_size * max_num_frames - - summary_engine, summary_sampling_params = load_summary_model( - max_num_frames=max_allowed_imgs_per_req, - max_tokens=max_summary_tokens, + prompt_gen_engine, prompt_gen_sampling_params = load_prompt_gen_model( + max_tokens=max_prompt_gen_tokens, num_devices=num_devices, download_dir=download_dir, trust_remote_code=trust_remote_code, ) dataloader = prepare_dataloader( - video_root_dir=root_dir, - output_dir=output_dir, - video_extensions=video_extensions, - max_num_frames=max_num_frames, + input_file=input_file, num_data_workers=num_data_workers, batch_size=batch_size, ) @@ -209,39 +152,15 @@ def main( save_future = save_thread.submit(save_results, output_queue, output_dir) try: - video_data = [] - for idx, batch in enumerate(dataloader): - conversations = create_video_summary_conversations(batch, prompt=video_summary_prompt) - video_summaries = summary_engine.chat(conversations, summary_sampling_params) - - video_data_item = { - "filename": batch["filename"], - "summary": [summary.outputs[0].text for summary in video_summaries] - } - - video_data.append(video_data_item) - - del summary_engine, summary_sampling_params - gc.collect() - torch.cuda.empty_cache() - - prompt_gen_engine, prompt_gen_sampling_params = load_prompt_gen_model( - max_tokens=max_prompt_gen_tokens, - num_devices=num_devices, - download_dir=download_dir, - trust_remote_code=trust_remote_code, - ) - - for idx, batch in enumerate(video_data): - conversations = create_prompt_generation_conversations(batch, prompt=prompt_gen_prompt) + conversations = create_prompt_generation_conversations(batch["summary"], prompt=prompt_gen_prompt) prompts = prompt_gen_engine.chat(conversations, prompt_gen_sampling_params) - # Get outputs and remove surrounding quotes - prompts = [prompt.outputs[0].text[1 : -1] for prompt in prompts] + # Get outputs and remove surrounding quotes/newlines + prompts = [" ".join(prompt.outputs[0].text.split("\n")) for prompt in prompts] + prompts = [prompt.lstrip("\"").rstrip("\"") for prompt in prompts] output_queue.put((batch["filename"], prompts)) - finally: output_queue.put(None) save_thread.shutdown(wait=True) diff --git a/captioning/vllm_summary.py b/captioning/vllm_summary.py new file mode 100644 index 00000000..f1ee0c58 --- /dev/null +++ b/captioning/vllm_summary.py @@ -0,0 +1,188 @@ +""" +Needs `vllm` to be installed from the `main`. +""" + +import os +import pathlib +import queue +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import fire +import pandas as pd +from torch.utils.data import DataLoader +from vllm import LLM, SamplingParams + +from dataset_video import VideoDataset # isort:skip + + +SYSTEM_PROMPT = r""" +You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. +For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model. +There are a few rules to follow: +- You will only ever output a single video description per request. +- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit. +You responses should just be the video generation prompt. Here are examples: +- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." +- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" +""".strip() + +SUMMARY_USER_PROMPT = r"""Please summarize this video and limit the summary to 100-200 words.""".strip() + +PROMPT_GEN_USER_PROMPT = r""" +Could you generate a prompt for a video generation model given the following summary: + +``` +{0} +``` + +Please limit the prompt to [{1}] words. +""".strip() + + +def save_results(output_queue, output_dir): + output_file = pathlib.Path(output_dir).joinpath("summary.csv") + + while True: + try: + item = output_queue.get(timeout=5) + if item is None: + break + + video_filenames, summaries = item + + df_data = [] + for filename, summary in zip(video_filenames, summaries): + df_data.append({ + "filename": filename, + "summary": summary, + }) + + df = pd.DataFrame(df_data) + df.to_csv(output_file, mode="a", header=not output_file.exists(), index=False) + + except queue.Empty: + continue + + +def create_video_summary_conversations(batch, prompt: Optional[str] = None): + if prompt is None: + prompt = SUMMARY_USER_PROMPT + + conversations = [] + + for i, video in enumerate(batch["videos"]): + conversation = [] + content = [] + + content.append({"type": "text", "text": prompt}) + for frame in video: + new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}} + content.append(new_image) + + # conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) + conversation.append({"role": "user", "content": content}) + + conversations.append(conversation) + + return conversations + + +def collate_fn(batch): + inputs = { + "videos": [sample["video"] for sample in batch], + "filename": [sample["filename"] for sample in batch], + } + return inputs + + +def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_frames, num_data_workers, batch_size): + dataset = VideoDataset( + video_root_dir, output_dir=output_dir, max_num_frames=max_num_frames, video_extensions=video_extensions + ) + + # Create DataLoader + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_data_workers, + pin_memory=True, + persistent_workers=True, + ) + return dataloader + + +def load_summary_model( + max_num_frames: int, + max_tokens: int, + num_devices: int, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + engine = LLM( + "openbmb/MiniCPM-V-2_6", + dtype="bfloat16", + tensor_parallel_size=num_devices, + limit_mm_per_prompt={"image": max_num_frames}, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + sampling_params = SamplingParams(min_tokens=64, max_tokens=max_tokens) + return engine, sampling_params + + +def main( + root_dir: str, + output_dir: str, + num_devices: int = 1, + max_num_frames: int = 8, + max_summary_tokens: int = 512, + video_summary_prompt: Optional[str] = None, + video_extensions: tuple = (".mp4"), + batch_size: int = 8, + num_data_workers: int = 4, + num_artifact_workers: int = 4, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + max_allowed_imgs_per_req = batch_size * max_num_frames + + summary_engine, summary_sampling_params = load_summary_model( + max_num_frames=max_allowed_imgs_per_req, + max_tokens=max_summary_tokens, + num_devices=num_devices, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + + dataloader = prepare_dataloader( + video_root_dir=root_dir, + output_dir=output_dir, + video_extensions=video_extensions, + max_num_frames=max_num_frames, + num_data_workers=num_data_workers, + batch_size=batch_size, + ) + + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=num_artifact_workers) + os.makedirs(output_dir, exist_ok=True) + save_future = save_thread.submit(save_results, output_queue, output_dir) + + try: + for idx, batch in enumerate(dataloader): + conversations = create_video_summary_conversations(batch, prompt=video_summary_prompt) + video_summaries = summary_engine.chat(conversations, summary_sampling_params) + summaries = [summary.outputs[0].text for summary in video_summaries] + output_queue.put((batch["filename"], summaries)) + finally: + output_queue.put(None) + save_thread.shutdown(wait=True) + + save_future.result() + print("All processes completed. Caption generation and saving done.") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/training/dataset.py b/training/dataset.py index dca375df..fa38de85 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -277,7 +277,9 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket)) From 5436d5840371aeb263f2234bf71191c924a54e74 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 13 Nov 2024 00:47:48 +0000 Subject: [PATCH 10/14] update --- captioning/dataset_caption.py | 5 ++--- captioning/dataset_video.py | 8 ++++---- captioning/vllm_caption.py | 16 ++++++---------- captioning/vllm_summary.py | 17 ++++++++++------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/captioning/dataset_caption.py b/captioning/dataset_caption.py index 9d642bfe..6806dee2 100644 --- a/captioning/dataset_caption.py +++ b/captioning/dataset_caption.py @@ -13,13 +13,12 @@ def __init__(self, input_file: str) -> None: df = pd.read_csv(input_file) self.filenames = df["filename"] self.summaries = df["summary"] - + def __len__(self): return len(self.filenames) - + def __getitem__(self, index: int): return { "filename": self.filenames[index], "summary": self.summaries[index], } - diff --git a/captioning/dataset_video.py b/captioning/dataset_video.py index 499fb47b..671372a5 100644 --- a/captioning/dataset_video.py +++ b/captioning/dataset_video.py @@ -62,15 +62,15 @@ def uniform_sample(l, n): gap = len(l) / n idxs = [int(i * gap + gap / 2) for i in range(n)] return [l[i] for i in idxs] - + sample_fps = round(video_reader.get_avg_fps() / 1) - frame_idx = [i for i in range(0, len(video_reader), sample_fps)] + frame_idx = list(range(0, len(video_reader), sample_fps)) if len(frame_idx) > self.max_num_frames: frame_idx = uniform_sample(frame_idx, self.max_num_frames) - + video_frames = [Image.fromarray(video_reader[i].asnumpy()) for i in frame_idx] - + return {"video": [self.encode_image(frame) for frame in video_frames], "filename": filename} def encode_image(self, image): diff --git a/captioning/vllm_caption.py b/captioning/vllm_caption.py index 5308f44b..234403f9 100644 --- a/captioning/vllm_caption.py +++ b/captioning/vllm_caption.py @@ -2,17 +2,16 @@ Needs `vllm` to be installed from the `main`. """ -import gc import os import queue from concurrent.futures import ThreadPoolExecutor from typing import Optional import fire -import torch from torch.utils.data import DataLoader from vllm import LLM, SamplingParams + from dataset_caption import CaptionDataset # isort:skip @@ -50,7 +49,7 @@ def save_results(output_queue, output_dir): with open(os.path.join(output_dir, "videos.txt"), "a", encoding="utf-8") as file: for filename in video_filenames: file.write(filename + "\n") - + with open(os.path.join(output_dir, "prompts.txt"), "a", encoding="utf-8") as file: for caption in captions: file.write(caption + "\n") @@ -61,17 +60,14 @@ def save_results(output_queue, output_dir): def create_prompt_generation_conversations(batch, prompt: Optional[str] = None): if prompt is None: prompt = PROMPT_GEN_USER_PROMPT - + conversations = [] for i, summary in enumerate(batch): conversation = [] content = [] - content.append({ - "type": "text", - "text": prompt.format(summary, 50) - }) + content.append({"type": "text", "text": prompt.format(summary, 50)}) conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) conversation.append({"role": "user", "content": content}) @@ -155,10 +151,10 @@ def main( for idx, batch in enumerate(dataloader): conversations = create_prompt_generation_conversations(batch["summary"], prompt=prompt_gen_prompt) prompts = prompt_gen_engine.chat(conversations, prompt_gen_sampling_params) - + # Get outputs and remove surrounding quotes/newlines prompts = [" ".join(prompt.outputs[0].text.split("\n")) for prompt in prompts] - prompts = [prompt.lstrip("\"").rstrip("\"") for prompt in prompts] + prompts = [prompt.lstrip('"').rstrip('"') for prompt in prompts] output_queue.put((batch["filename"], prompts)) finally: diff --git a/captioning/vllm_summary.py b/captioning/vllm_summary.py index f1ee0c58..5606f80e 100644 --- a/captioning/vllm_summary.py +++ b/captioning/vllm_summary.py @@ -13,6 +13,7 @@ from torch.utils.data import DataLoader from vllm import LLM, SamplingParams + from dataset_video import VideoDataset # isort:skip @@ -53,11 +54,13 @@ def save_results(output_queue, output_dir): df_data = [] for filename, summary in zip(video_filenames, summaries): - df_data.append({ - "filename": filename, - "summary": summary, - }) - + df_data.append( + { + "filename": filename, + "summary": summary, + } + ) + df = pd.DataFrame(df_data) df.to_csv(output_file, mode="a", header=not output_file.exists(), index=False) @@ -147,7 +150,7 @@ def main( trust_remote_code: bool = False, ): max_allowed_imgs_per_req = batch_size * max_num_frames - + summary_engine, summary_sampling_params = load_summary_model( max_num_frames=max_allowed_imgs_per_req, max_tokens=max_summary_tokens, @@ -173,7 +176,7 @@ def main( try: for idx, batch in enumerate(dataloader): conversations = create_video_summary_conversations(batch, prompt=video_summary_prompt) - video_summaries = summary_engine.chat(conversations, summary_sampling_params) + video_summaries = summary_engine.chat(conversations, summary_sampling_params) summaries = [summary.outputs[0].text for summary in video_summaries] output_queue.put((batch["filename"], summaries)) finally: From 343d654b89b23ebb276c9a106c09bf628e8c12a2 Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 13 Nov 2024 13:51:38 +0000 Subject: [PATCH 11/14] update accelerate ocnfig --- accelerate_configs/uncompiled_8.yaml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 accelerate_configs/uncompiled_8.yaml diff --git a/accelerate_configs/uncompiled_8.yaml b/accelerate_configs/uncompiled_8.yaml new file mode 100644 index 00000000..27539bb5 --- /dev/null +++ b/accelerate_configs/uncompiled_8.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false From 415021cef21256c9b026d9f1bc14ba4fdb8321ac Mon Sep 17 00:00:00 2001 From: Aryan Date: Wed, 13 Nov 2024 13:54:04 +0000 Subject: [PATCH 12/14] update slurm scripts --- slurm/prepare_dataset.slurm | 52 +++++++++++++++++++++++++++++++++++++ slurm/train_lora_t2v.slurm | 28 ++++++++++++++++++++ slurm/vllm_caption.slurm | 43 ++++++++++++++++++++++++++++++ slurm/vllm_summary.slurm | 45 ++++++++++++++++++++++++++++++++ 4 files changed, 168 insertions(+) create mode 100644 slurm/prepare_dataset.slurm create mode 100644 slurm/train_lora_t2v.slurm create mode 100755 slurm/vllm_caption.slurm create mode 100755 slurm/vllm_summary.slurm diff --git a/slurm/prepare_dataset.slurm b/slurm/prepare_dataset.slurm new file mode 100644 index 00000000..fba162aa --- /dev/null +++ b/slurm/prepare_dataset.slurm @@ -0,0 +1,52 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +MODEL_ID="THUDM/CogVideoX-5b" +NUM_GPUS=8 + +DATA_ROOT="/path/to/datasets/dataset-cakify-yt" +CAPTION_COLUMN="prompts.txt" +VIDEO_COLUMN="videos.txt" +OUTPUT_DIR="/path/to/datasets/dataset-cakify-yt-encoded" +HEIGHT_BUCKETS="480 720 768" +WIDTH_BUCKETS="480 720 768" +FRAME_BUCKETS="17 33 49" +MAX_NUM_FRAMES=49 +MAX_SEQUENCE_LENGTH=226 +TARGET_FPS=8 +BATCH_SIZE=1 +DTYPE=fp32 +ADDITIONAL_FLAGS="--save_image_latents --save_latents_and_embeddings" + +set -xe + +module load cuda/12.1 + +echo "Starting job" + +srun torchrun --nproc_per_node=$NUM_GPUS training/prepare_dataset.py \ + --model_id $MODEL_ID \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --output_dir $OUTPUT_DIR \ + --height_buckets $HEIGHT_BUCKETS \ + --width_buckets $WIDTH_BUCKETS \ + --frame_buckets $FRAME_BUCKETS \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_sequence_length $MAX_SEQUENCE_LENGTH \ + --target_fps $TARGET_FPS \ + --batch_size $BATCH_SIZE \ + --dtype $DTYPE \ + $ADDITIONAL_FLAGS + +echo "End time: $(date)" diff --git a/slurm/train_lora_t2v.slurm b/slurm/train_lora_t2v.slurm new file mode 100644 index 00000000..5d1d9209 --- /dev/null +++ b/slurm/train_lora_t2v.slurm @@ -0,0 +1,28 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=168:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +set -xe + +echo "START TIME: $(date)" + +# Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +module load cuda/12.1 + +echo "Starting job" + +srun dump_train_text_to_video_lora.sh + +echo "END TIME: $(date)" diff --git a/slurm/vllm_caption.slurm b/slurm/vllm_caption.slurm new file mode 100755 index 00000000..456d25c9 --- /dev/null +++ b/slurm/vllm_caption.slurm @@ -0,0 +1,43 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +INPUT_FILE="/path/to/datasets/dataset-cakify-yt/summary.csv" +OUTPUT_DIR="/path/to/datasets/dataset-cakify-yt/" +NUM_DEVICES=8 +BATCH_SIZE=4 +MAX_PROMPT_GEN_TOKENS=256 +NUM_DATA_WORKERS=8 +NUM_ARTIFACT_WORKERS=8 +ADDITIONAL_FLAGS="--trust_remote_code" + +PROMPT_GEN_PROMPT="Could you generate a prompt for a text-to-video generation model given the following summary:\n\n\`\`\`\n{0}\n\`\`\`\n\nPlease limit the prompt to [{1}] words. To provide some additional context, these summaries are generated for videos that depict cutting of tasty cakes disguised as a realistic looking object." + +set -xe + +module load cuda/12.1 + +echo "Starting job" + +srun python3 captioning/vllm_caption.py \ + --input_file $INPUT_FILE \ + --output_dir $OUTPUT_DIR \ + --num_devices $NUM_DEVICES \ + --max_prompt_gen_tokens $MAX_PROMPT_GEN_TOKENS \ + --prompt_gen_prompt "$PROMPT_GEN_PROMPT" \ + --batch_size $BATCH_SIZE \ + --num_data_workers $NUM_DATA_WORKERS \ + --num_artifact_workers $NUM_ARTIFACT_WORKERS \ + $ADDITIONAL_FLAGS + +echo "End time: $(date)" diff --git a/slurm/vllm_summary.slurm b/slurm/vllm_summary.slurm new file mode 100755 index 00000000..2260bb03 --- /dev/null +++ b/slurm/vllm_summary.slurm @@ -0,0 +1,45 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +ROOT_DIR="/path/to/datasets/dataset-cakify-yt/videos" +OUTPUT_DIR="/path/to/datasets/dataset-cakify-yt/" +NUM_DEVICES=4 +BATCH_SIZE=4 +MAX_NUM_FRAMES=16 +MAX_SUMMARY_TOKENS=2048 +NUM_DATA_WORKERS=8 +NUM_ARTIFACT_WORKERS=8 +ADDITIONAL_FLAGS="--trust_remote_code" + +VIDEO_SUMMARY_PROMPT="Summarize this video and limit to 500 words. The sequence of video frames are of a person cutting cakes shaped as different realistic objects. Make sure to mention about the cake cutting in the summary." + +set -xe + +module load cuda/12.1 + +echo "Starting job" + +srun python3 captioning/vllm_summary.py \ + --root_dir $ROOT_DIR \ + --output_dir $OUTPUT_DIR \ + --num_devices $NUM_DEVICES \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_summary_tokens $MAX_SUMMARY_TOKENS \ + --video_summary_prompt "$VIDEO_SUMMARY_PROMPT" \ + --batch_size $BATCH_SIZE \ + --num_data_workers $NUM_DATA_WORKERS \ + --num_artifact_workers $NUM_ARTIFACT_WORKERS \ + $ADDITIONAL_FLAGS + +echo "End time: $(date)" From 5ab3743b7985445185381c8deaf8fdae77127741 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 14 Nov 2024 13:27:04 +0000 Subject: [PATCH 13/14] update --- data/cakify_eval_prompts.csv | 27 ++++++ inference/text_to_video_lora.py | 148 ++++++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+) create mode 100644 data/cakify_eval_prompts.csv create mode 100644 inference/text_to_video_lora.py diff --git a/data/cakify_eval_prompts.csv b/data/cakify_eval_prompts.csv new file mode 100644 index 00000000..8d556d54 --- /dev/null +++ b/data/cakify_eval_prompts.csv @@ -0,0 +1,27 @@ +prompt,height,width,num_frames +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,480,17 +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,480,49 +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,720,17 +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,720,49 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,17 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,49 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,720,17 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,720,49 +"{id_token} A vibrant orange cake disguised as a Nike packaging box sits on a dark surface, meticulous in its detail and design, complete with a white swoosh and 'NIKE' logo. A person's hands, holding a knife, hover over the cake, ready to make a precise cut, amidst a simple and clean background.",480,480,17 +"{id_token} A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,33 +"{id_token} A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,49 +"{id_token} A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.",768,720,17 +"{id_token} A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.",768,720,33 +"{id_token} A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.",768,720,49 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,480,49 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,720,49 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,480,33 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,720,33 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,480,17 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,720,17 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,480,49 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,720,49 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,480,33 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,720,33 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,480,17 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,720,17 diff --git a/inference/text_to_video_lora.py b/inference/text_to_video_lora.py new file mode 100644 index 00000000..edabf9c3 --- /dev/null +++ b/inference/text_to_video_lora.py @@ -0,0 +1,148 @@ +import pathlib +from typing import Any, Dict, Optional + +import fire +import pandas as pd +import torch +from accelerate import Accelerator +from accelerate.utils import gather_object +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm + + +class PromptDataset(Dataset): + def __init__(self, filename: str, id_token: Optional[str] = None) -> None: + super().__init__() + + self.id_token = id_token or "" + + df = pd.read_csv(filename) + + self.prompts = df["prompt"] + self.heights = df["height"] + self.widths = df["width"] + self.num_frames = df["num_frames"] + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + prompt = self.prompts[index].format(id_token=self.id_token).strip() + return { + "prompt": prompt, + "height": self.heights[index], + "width": self.widths[index], + "num_frames": self.num_frames[index], + } + + +class CollateFunction: + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompt = [x["prompt"] for x in data] + height = [x["height"] for x in data] + width = [x["width"] for x in data] + num_frames = [x["num_frames"] for x in data] + + return { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + } + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def prompt_to_filename(x: str) -> str: + for c in ["\"", "'", ",", "\\", " "]: + x = x.replace(c, "-") + return x + + +def main( + dataset_file: str, + model_id: str = "THUDM/CogVideoX-5b", + lora_id: Optional[str] = None, + id_token: Optional[str] = None, + dtype: str = "bf16", + enable_model_cpu_offload: bool = False, + output_dir: str = "text_to_video_lora_outputs", + save_fps: int = 8, + seed: int = 42, +) -> None: + dataset = PromptDataset(dataset_file, id_token) + + collate_fn = CollateFunction() + dataloader = DataLoader( + dataset=dataset, + batch_size=1, + collate_fn=collate_fn, + shuffle=False, + num_workers=1, + ) + + output_dir: pathlib.Path = pathlib.Path(output_dir) + output_dir.mkdir(exist_ok=True) + + pipe = CogVideoXPipeline.from_pretrained( + model_id, + torch_dtype=DTYPE_MAPPING[dtype] + ) + + if lora_id is not None: + pipe.load_lora_weights(lora_id) + + accelerator = Accelerator() + + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=accelerator.device.index) + else: + pipe = pipe.to(accelerator.device) + + count = 0 + for _, data_raw in tqdm(enumerate(dataloader), total=len(dataloader)): + input_data = [] + + with accelerator.split_between_processes(data_raw) as data: + video = pipe( + prompt=data["prompt"][0], + height=data["height"][0], + width=data["width"][0], + num_frames=data["num_frames"][0], + num_inference_steps=50, + guidance_scale=6.0, + use_dynamic_cfg=False, + generator=torch.Generator().manual_seed(seed), + ).frames[0] + input_data.append(data.items()) + + video = gather_object(video) + input_data = gather_object(input_data) + + if accelerator.is_main_process: + count += 1 + input_data = dict(input_data[0]) + + filename = "" + filename += f"height_{input_data['height'][0]}" + "---" + filename += f"width_{input_data['width'][0]}" + "---" + filename += f"num_frames_{input_data['num_frames'][0]}" + "---" + filename += prompt_to_filename(input_data["prompt"][0])[:25] + filename += ".mp4" + filename = output_dir.joinpath(filename) + + export_to_video(video, filename, fps=save_fps) + + if accelerator.is_main_process: + print(f"Text-to-video generation for LoRA completed. Results saved in {output_dir}.") + + +if __name__ == "__main__": + fire.Fire(main) From 0f89efc43c30dfa3cbdb4bfa18db43fec4df49f4 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 14 Nov 2024 14:02:01 +0000 Subject: [PATCH 14/14] update --- inference/text_to_video_lora.py | 29 +++++++++++++++++------------ slurm/inference_lora_t2v.slurm | 22 ++++++++++++++++++++++ slurm/vllm_caption.slurm | 1 - 3 files changed, 39 insertions(+), 13 deletions(-) create mode 100644 slurm/inference_lora_t2v.slurm diff --git a/inference/text_to_video_lora.py b/inference/text_to_video_lora.py index edabf9c3..04ecb4eb 100644 --- a/inference/text_to_video_lora.py +++ b/inference/text_to_video_lora.py @@ -1,4 +1,5 @@ import pathlib +import uuid from typing import Any, Dict, Optional import fire @@ -96,6 +97,8 @@ def main( torch_dtype=DTYPE_MAPPING[dtype] ) + print("LoRA ID:", lora_id) + if lora_id is not None: pipe.load_lora_weights(lora_id) @@ -121,24 +124,26 @@ def main( use_dynamic_cfg=False, generator=torch.Generator().manual_seed(seed), ).frames[0] - input_data.append(data.items()) + input_data.append(list(data.items())) video = gather_object(video) input_data = gather_object(input_data) if accelerator.is_main_process: count += 1 - input_data = dict(input_data[0]) - - filename = "" - filename += f"height_{input_data['height'][0]}" + "---" - filename += f"width_{input_data['width'][0]}" + "---" - filename += f"num_frames_{input_data['num_frames'][0]}" + "---" - filename += prompt_to_filename(input_data["prompt"][0])[:25] - filename += ".mp4" - filename = output_dir.joinpath(filename) - - export_to_video(video, filename, fps=save_fps) + for data in input_data: + data = dict(data) + + filename = "" + filename += f"height_{data['height'][0]}" + "---" + filename += f"width_{data['width'][0]}" + "---" + filename += f"num_frames_{data['num_frames'][0]}" + "---" + filename += prompt_to_filename(data["prompt"][0])[:25] + "---" + filename += str(uuid.uuid4()) + filename += ".mp4" + filename = output_dir.joinpath(filename) + + export_to_video(video, filename, fps=save_fps) if accelerator.is_main_process: print(f"Text-to-video generation for LoRA completed. Results saved in {output_dir}.") diff --git a/slurm/inference_lora_t2v.slurm b/slurm/inference_lora_t2v.slurm new file mode 100644 index 00000000..00240c18 --- /dev/null +++ b/slurm/inference_lora_t2v.slurm @@ -0,0 +1,22 @@ +#!/bin/bash +#SBATCH --job-name=cogvideox-lora-inference +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out + +set -xe + +module load cuda/12.1 + +echo "START TIME: $(date)" + +echo "Starting job" + +srun dump_inference_lora_t2v.sh + +echo "END TIME: $(date)" diff --git a/slurm/vllm_caption.slurm b/slurm/vllm_caption.slurm index 456d25c9..6826400c 100755 --- a/slurm/vllm_caption.slurm +++ b/slurm/vllm_caption.slurm @@ -8,7 +8,6 @@ #SBATCH --gres=gpu:8 #SBATCH --partition=hopper-prod #SBATCH -o /path/to/logs/logs-%x-%j.out -#SBATCH --exclusive export VLLM_WORKER_MULTIPROC_METHOD=spawn