diff --git a/Makefile b/Makefile index be649818..d0c7cb1e 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ .PHONY: quality style -check_dirs := training tests +check_dirs := training tests captioning quality: ruff check $(check_dirs) diff --git a/accelerate_configs/uncompiled_8.yaml b/accelerate_configs/uncompiled_8.yaml new file mode 100644 index 00000000..27539bb5 --- /dev/null +++ b/accelerate_configs/uncompiled_8.yaml @@ -0,0 +1,17 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +enable_cpu_affinity: false +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/captioning/dataset_caption.py b/captioning/dataset_caption.py new file mode 100644 index 00000000..6806dee2 --- /dev/null +++ b/captioning/dataset_caption.py @@ -0,0 +1,24 @@ +import pathlib + +import pandas as pd +from torch.utils.data import Dataset + + +class CaptionDataset(Dataset): + def __init__(self, input_file: str) -> None: + self.input_file = pathlib.Path(input_file) + + assert self.input_file.is_file() + + df = pd.read_csv(input_file) + self.filenames = df["filename"] + self.summaries = df["summary"] + + def __len__(self): + return len(self.filenames) + + def __getitem__(self, index: int): + return { + "filename": self.filenames[index], + "summary": self.summaries[index], + } diff --git a/captioning/dataset_video.py b/captioning/dataset_video.py new file mode 100644 index 00000000..671372a5 --- /dev/null +++ b/captioning/dataset_video.py @@ -0,0 +1,98 @@ +import base64 +import io +import os +import sys +from typing import List, Tuple + +from PIL import Image +from torch.utils.data import Dataset + + +import decord # isort:skip + + +class VideoDataset(Dataset): + def __init__( + self, root_video_dir: str, output_dir: str, max_num_frames: int, video_extensions: Tuple[str] = (".mp4") + ): + self.root_video_dir = root_video_dir + self.max_num_frames = max_num_frames + + # Filter out existing captions. + video_files = { + os.path.join(root_video_dir, f) for f in os.listdir(root_video_dir) if f.endswith(video_extensions) + } + if os.path.isdir(output_dir): + existing_caption_basenames = { + os.path.splitext(f)[0] for f in os.listdir(output_dir) if "_caption.txt" in f + } + else: + existing_caption_basenames = None + if existing_caption_basenames: + if len(existing_caption_basenames) == len(video_files): + sys.exit( + "It seems like all the input videos have been already captioned. So, we're exiting the program." + ) + filtered_video_files = [ + f + for f in video_files + if os.path.splitext(os.path.basename(f))[0] + "_caption" not in existing_caption_basenames + ] + if len(video_files) > len(filtered_video_files): + diff = len(video_files) - len(filtered_video_files) + print(f"Found existing captions for {diff} videos. Will skip them.") + + self.video_files = sorted(filtered_video_files) + else: + self.video_files = sorted(video_files) + print(f"Total videos found: {len(self.video_files)}.") + + def __len__(self) -> int: + return len(self.video_files) + + def __getitem__(self, index: int) -> List[Image.Image]: + video_path = self.video_files[index] + return self.load_video(video_path) + + def load_video(self, path: str) -> List[Image.Image]: + video_reader = decord.VideoReader(uri=path) + filename = os.path.basename(path) + + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + sample_fps = round(video_reader.get_avg_fps() / 1) + frame_idx = list(range(0, len(video_reader), sample_fps)) + + if len(frame_idx) > self.max_num_frames: + frame_idx = uniform_sample(frame_idx, self.max_num_frames) + + video_frames = [Image.fromarray(video_reader[i].asnumpy()) for i in frame_idx] + + return {"video": [self.encode_image(frame) for frame in video_frames], "filename": filename} + + def encode_image(self, image): + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + image_bytes = buffered.getvalue() + return base64.b64encode(image_bytes).decode("utf-8") + + +if __name__ == "__main__": + import tempfile + + from huggingface_hub import snapshot_download + + with tempfile.TemporaryDirectory() as tmpdirname: + video_root_dir = snapshot_download( + repo_id="Wild-Heart/Disney-VideoGeneration-Dataset", repo_type="dataset", local_dir=tmpdirname + ) + + dataset = VideoDataset(os.path.join(video_root_dir, "videos"), max_num_frames=16) + print(len(dataset)) + + for item in dataset: + print(len(item["video"])) + break diff --git a/captioning/launch.sh b/captioning/launch.sh new file mode 100755 index 00000000..aafb7b2e --- /dev/null +++ b/captioning/launch.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +export CUDA_VISIBLE_DEVICES=0,1 + +# Path to where vLLM models should be downloaded +DOWNLOAD_DIR="/path/to/download/dir" + +# Path to where video files are located +ROOT_DIR="/path/to/video/files" + +# Path to where captions should be stored +OUTPUT_DIR="/path/to/save/captions" + +# Other configurations +MAX_FRAMES=8 +MAX_TOKENS=120 +BATCH_SIZE=2 +NUM_DATA_WORKERS=4 +NUM_ARTIFACT_WORKERS=4 + +PROMPT="Please describe the content of this video in as much detail as possible, including the objects, scenery, animals, characters, and camera movements within the video. Do not include '\n' in your response. Please start the description with the video content directly. Please describe the content of the video and the changes that occur, in chronological order." + +python recaption.py \ + --root_dir $ROOT_DIR \ + --output_dir $OUTPUT_DIR \ + --num_devices 1 \ + --max_num_frames $MAX_FRAMES \ + --max_tokens $MAX_TOKENS \ + --num_data_workers $NUM_DATA_WORKERS \ + --batch_size $BATCH_SIZE \ + --prompt $PROMPT \ + --num_artifact_workers $NUM_ARTIFACT_WORKERS \ + --download_dir $DOWNLOAD_DIR diff --git a/captioning/vllm_caption.py b/captioning/vllm_caption.py new file mode 100644 index 00000000..234403f9 --- /dev/null +++ b/captioning/vllm_caption.py @@ -0,0 +1,169 @@ +""" +Needs `vllm` to be installed from the `main`. +""" + +import os +import queue +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import fire +from torch.utils.data import DataLoader +from vllm import LLM, SamplingParams + + +from dataset_caption import CaptionDataset # isort:skip + + +SYSTEM_PROMPT = r""" +You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. +For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model. +There are a few rules to follow: +- You will only ever output a single video description per request. Do not use newlines. +- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit. +You responses should just be the video generation prompt. Here are examples: +- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." +- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" +""".strip() + +PROMPT_GEN_USER_PROMPT = r""" +Could you generate a prompt for a video generation model given the following summary: + +``` +{0} +``` + +Please limit the prompt to [{1}] words. +""".strip() + + +def save_results(output_queue, output_dir): + while True: + try: + item = output_queue.get(timeout=5) + if item is None: + break + + video_filenames, captions = item + + with open(os.path.join(output_dir, "videos.txt"), "a", encoding="utf-8") as file: + for filename in video_filenames: + file.write(filename + "\n") + + with open(os.path.join(output_dir, "prompts.txt"), "a", encoding="utf-8") as file: + for caption in captions: + file.write(caption + "\n") + except queue.Empty: + continue + + +def create_prompt_generation_conversations(batch, prompt: Optional[str] = None): + if prompt is None: + prompt = PROMPT_GEN_USER_PROMPT + + conversations = [] + + for i, summary in enumerate(batch): + conversation = [] + content = [] + + content.append({"type": "text", "text": prompt.format(summary, 50)}) + + conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) + conversation.append({"role": "user", "content": content}) + + conversations.append(conversation) + + return conversations + + +def collate_fn(batch): + inputs = { + "summary": [sample["summary"] for sample in batch], + "filename": [sample["filename"] for sample in batch], + } + return inputs + + +def prepare_dataloader(input_file, num_data_workers, batch_size): + dataset = CaptionDataset(input_file) + + # Create DataLoader + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_data_workers, + pin_memory=True, + persistent_workers=True, + ) + return dataloader + + +def load_prompt_gen_model( + max_tokens: int, + num_devices: int, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + engine = LLM( + "meta-llama/Meta-Llama-3.1-8B-Instruct", + dtype="bfloat16", + tensor_parallel_size=num_devices, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + sampling_params = SamplingParams(max_tokens=max_tokens) + return engine, sampling_params + + +def main( + input_file: str, + output_dir: str, + num_devices: int = 1, + max_prompt_gen_tokens: int = 256, + prompt_gen_prompt: Optional[str] = None, + batch_size: int = 8, + num_data_workers: int = 4, + num_artifact_workers: int = 4, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + prompt_gen_engine, prompt_gen_sampling_params = load_prompt_gen_model( + max_tokens=max_prompt_gen_tokens, + num_devices=num_devices, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + + dataloader = prepare_dataloader( + input_file=input_file, + num_data_workers=num_data_workers, + batch_size=batch_size, + ) + + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=num_artifact_workers) + os.makedirs(output_dir, exist_ok=True) + save_future = save_thread.submit(save_results, output_queue, output_dir) + + try: + for idx, batch in enumerate(dataloader): + conversations = create_prompt_generation_conversations(batch["summary"], prompt=prompt_gen_prompt) + prompts = prompt_gen_engine.chat(conversations, prompt_gen_sampling_params) + + # Get outputs and remove surrounding quotes/newlines + prompts = [" ".join(prompt.outputs[0].text.split("\n")) for prompt in prompts] + prompts = [prompt.lstrip('"').rstrip('"') for prompt in prompts] + + output_queue.put((batch["filename"], prompts)) + finally: + output_queue.put(None) + save_thread.shutdown(wait=True) + + save_future.result() + print("All processes completed. Caption generation and saving done.") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/captioning/vllm_summary.py b/captioning/vllm_summary.py new file mode 100644 index 00000000..5606f80e --- /dev/null +++ b/captioning/vllm_summary.py @@ -0,0 +1,191 @@ +""" +Needs `vllm` to be installed from the `main`. +""" + +import os +import pathlib +import queue +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import fire +import pandas as pd +from torch.utils.data import DataLoader +from vllm import LLM, SamplingParams + + +from dataset_video import VideoDataset # isort:skip + + +SYSTEM_PROMPT = r""" +You are part of a team of people that create videos using generative models. You use a video-generation model that can generate a video about anything you describe. +For example, if you respond with "A beautiful morning in the woods with the sun peaking through the trees", the video generation model will create a video of exactly as described. You task is to summarize the descriptions of videos provided to by users, and create details prompts to feed into the generative model. +There are a few rules to follow: +- You will only ever output a single video description per request. +- If the user mentions to summarize the prompt in [X] words, make sure to not exceed the limit. +You responses should just be the video generation prompt. Here are examples: +- "A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting." +- "A street artist, clad in a worn-out denim jacket and a colorful bandana, stands before a vast concrete wall in the heart, holding a can of spray paint, spray-painting a colorful bird on a mottled wall" +""".strip() + +SUMMARY_USER_PROMPT = r"""Please summarize this video and limit the summary to 100-200 words.""".strip() + +PROMPT_GEN_USER_PROMPT = r""" +Could you generate a prompt for a video generation model given the following summary: + +``` +{0} +``` + +Please limit the prompt to [{1}] words. +""".strip() + + +def save_results(output_queue, output_dir): + output_file = pathlib.Path(output_dir).joinpath("summary.csv") + + while True: + try: + item = output_queue.get(timeout=5) + if item is None: + break + + video_filenames, summaries = item + + df_data = [] + for filename, summary in zip(video_filenames, summaries): + df_data.append( + { + "filename": filename, + "summary": summary, + } + ) + + df = pd.DataFrame(df_data) + df.to_csv(output_file, mode="a", header=not output_file.exists(), index=False) + + except queue.Empty: + continue + + +def create_video_summary_conversations(batch, prompt: Optional[str] = None): + if prompt is None: + prompt = SUMMARY_USER_PROMPT + + conversations = [] + + for i, video in enumerate(batch["videos"]): + conversation = [] + content = [] + + content.append({"type": "text", "text": prompt}) + for frame in video: + new_image = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{frame}"}} + content.append(new_image) + + # conversation.append({"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}) + conversation.append({"role": "user", "content": content}) + + conversations.append(conversation) + + return conversations + + +def collate_fn(batch): + inputs = { + "videos": [sample["video"] for sample in batch], + "filename": [sample["filename"] for sample in batch], + } + return inputs + + +def prepare_dataloader(video_root_dir, output_dir, video_extensions, max_num_frames, num_data_workers, batch_size): + dataset = VideoDataset( + video_root_dir, output_dir=output_dir, max_num_frames=max_num_frames, video_extensions=video_extensions + ) + + # Create DataLoader + dataloader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_data_workers, + pin_memory=True, + persistent_workers=True, + ) + return dataloader + + +def load_summary_model( + max_num_frames: int, + max_tokens: int, + num_devices: int, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + engine = LLM( + "openbmb/MiniCPM-V-2_6", + dtype="bfloat16", + tensor_parallel_size=num_devices, + limit_mm_per_prompt={"image": max_num_frames}, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + sampling_params = SamplingParams(min_tokens=64, max_tokens=max_tokens) + return engine, sampling_params + + +def main( + root_dir: str, + output_dir: str, + num_devices: int = 1, + max_num_frames: int = 8, + max_summary_tokens: int = 512, + video_summary_prompt: Optional[str] = None, + video_extensions: tuple = (".mp4"), + batch_size: int = 8, + num_data_workers: int = 4, + num_artifact_workers: int = 4, + download_dir: Optional[str] = None, + trust_remote_code: bool = False, +): + max_allowed_imgs_per_req = batch_size * max_num_frames + + summary_engine, summary_sampling_params = load_summary_model( + max_num_frames=max_allowed_imgs_per_req, + max_tokens=max_summary_tokens, + num_devices=num_devices, + download_dir=download_dir, + trust_remote_code=trust_remote_code, + ) + + dataloader = prepare_dataloader( + video_root_dir=root_dir, + output_dir=output_dir, + video_extensions=video_extensions, + max_num_frames=max_num_frames, + num_data_workers=num_data_workers, + batch_size=batch_size, + ) + + output_queue = queue.Queue() + save_thread = ThreadPoolExecutor(max_workers=num_artifact_workers) + os.makedirs(output_dir, exist_ok=True) + save_future = save_thread.submit(save_results, output_queue, output_dir) + + try: + for idx, batch in enumerate(dataloader): + conversations = create_video_summary_conversations(batch, prompt=video_summary_prompt) + video_summaries = summary_engine.chat(conversations, summary_sampling_params) + summaries = [summary.outputs[0].text for summary in video_summaries] + output_queue.put((batch["filename"], summaries)) + finally: + output_queue.put(None) + save_thread.shutdown(wait=True) + + save_future.result() + print("All processes completed. Caption generation and saving done.") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/data/cakify_eval_prompts.csv b/data/cakify_eval_prompts.csv new file mode 100644 index 00000000..8d556d54 --- /dev/null +++ b/data/cakify_eval_prompts.csv @@ -0,0 +1,27 @@ +prompt,height,width,num_frames +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,480,17 +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,480,49 +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,720,17 +"{id_token} A cake shaped like a Nutella container is carefully sliced, revealing a light interior, amidst a Nutella-themed setup, showcasing deliberate cutting and preserved details for an appetizing dessert presentation on a white base with accompanying jello and cutlery, highlighting culinary skills and creative cake designs.",480,720,49 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,17 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,49 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,720,17 +"A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,720,49 +"{id_token} A vibrant orange cake disguised as a Nike packaging box sits on a dark surface, meticulous in its detail and design, complete with a white swoosh and 'NIKE' logo. A person's hands, holding a knife, hover over the cake, ready to make a precise cut, amidst a simple and clean background.",480,480,17 +"{id_token} A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,33 +"{id_token} A Nutella-shaped cake is sliced open to reveal a light, moist interior. Set against a Nutella-themed backdrop, the cake sits on a white base with jello and cutlery, emphasizing precise craftsmanship, appetizing details, and creative dessert presentation.",480,480,49 +"{id_token} A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.",768,720,17 +"{id_token} A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.",768,720,33 +"{id_token} A person with gloved hands carefully cuts a cake shaped like a Skittles bottle, beginning with a precise incision at the lid, followed by careful sequential cuts around the neck, eventually detaching the lid from the body, revealing the chocolate interior of the cake while showcasing the layered design's detail.",768,720,49 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,480,49 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,720,49 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,480,33 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,720,33 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,480,17 +"{id_token} A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance",480,720,17 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,480,49 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,720,49 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,480,33 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,720,33 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,480,17 +"{id_token} A panda in a red jacket and tiny hat sits on a wooden stool in a tranquil bamboo forest, strumming a miniature guitar. Other pandas gather nearby, some clapping along as sunlight filters through the bamboo, casting a gentle glow. The panda's face shows joy and focus, adding to the peaceful, magical atmosphere with a small stream and lush foliage in the background.",480,720,17 diff --git a/inference/text_to_video_lora.py b/inference/text_to_video_lora.py new file mode 100644 index 00000000..04ecb4eb --- /dev/null +++ b/inference/text_to_video_lora.py @@ -0,0 +1,153 @@ +import pathlib +import uuid +from typing import Any, Dict, Optional + +import fire +import pandas as pd +import torch +from accelerate import Accelerator +from accelerate.utils import gather_object +from diffusers import CogVideoXPipeline +from diffusers.utils import export_to_video +from torch.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm + + +class PromptDataset(Dataset): + def __init__(self, filename: str, id_token: Optional[str] = None) -> None: + super().__init__() + + self.id_token = id_token or "" + + df = pd.read_csv(filename) + + self.prompts = df["prompt"] + self.heights = df["height"] + self.widths = df["width"] + self.num_frames = df["num_frames"] + + def __len__(self): + return len(self.prompts) + + def __getitem__(self, index): + prompt = self.prompts[index].format(id_token=self.id_token).strip() + return { + "prompt": prompt, + "height": self.heights[index], + "width": self.widths[index], + "num_frames": self.num_frames[index], + } + + +class CollateFunction: + def __call__(self, data: Dict[str, Any]) -> Dict[str, torch.Tensor]: + prompt = [x["prompt"] for x in data] + height = [x["height"] for x in data] + width = [x["width"] for x in data] + num_frames = [x["num_frames"] for x in data] + + return { + "prompt": prompt, + "height": height, + "width": width, + "num_frames": num_frames, + } + + +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} + + +def prompt_to_filename(x: str) -> str: + for c in ["\"", "'", ",", "\\", " "]: + x = x.replace(c, "-") + return x + + +def main( + dataset_file: str, + model_id: str = "THUDM/CogVideoX-5b", + lora_id: Optional[str] = None, + id_token: Optional[str] = None, + dtype: str = "bf16", + enable_model_cpu_offload: bool = False, + output_dir: str = "text_to_video_lora_outputs", + save_fps: int = 8, + seed: int = 42, +) -> None: + dataset = PromptDataset(dataset_file, id_token) + + collate_fn = CollateFunction() + dataloader = DataLoader( + dataset=dataset, + batch_size=1, + collate_fn=collate_fn, + shuffle=False, + num_workers=1, + ) + + output_dir: pathlib.Path = pathlib.Path(output_dir) + output_dir.mkdir(exist_ok=True) + + pipe = CogVideoXPipeline.from_pretrained( + model_id, + torch_dtype=DTYPE_MAPPING[dtype] + ) + + print("LoRA ID:", lora_id) + + if lora_id is not None: + pipe.load_lora_weights(lora_id) + + accelerator = Accelerator() + + if enable_model_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=accelerator.device.index) + else: + pipe = pipe.to(accelerator.device) + + count = 0 + for _, data_raw in tqdm(enumerate(dataloader), total=len(dataloader)): + input_data = [] + + with accelerator.split_between_processes(data_raw) as data: + video = pipe( + prompt=data["prompt"][0], + height=data["height"][0], + width=data["width"][0], + num_frames=data["num_frames"][0], + num_inference_steps=50, + guidance_scale=6.0, + use_dynamic_cfg=False, + generator=torch.Generator().manual_seed(seed), + ).frames[0] + input_data.append(list(data.items())) + + video = gather_object(video) + input_data = gather_object(input_data) + + if accelerator.is_main_process: + count += 1 + for data in input_data: + data = dict(data) + + filename = "" + filename += f"height_{data['height'][0]}" + "---" + filename += f"width_{data['width'][0]}" + "---" + filename += f"num_frames_{data['num_frames'][0]}" + "---" + filename += prompt_to_filename(data["prompt"][0])[:25] + "---" + filename += str(uuid.uuid4()) + filename += ".mp4" + filename = output_dir.joinpath(filename) + + export_to_video(video, filename, fps=save_fps) + + if accelerator.is_main_process: + print(f"Text-to-video generation for LoRA completed. Results saved in {output_dir}.") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/slurm/inference_lora_t2v.slurm b/slurm/inference_lora_t2v.slurm new file mode 100644 index 00000000..00240c18 --- /dev/null +++ b/slurm/inference_lora_t2v.slurm @@ -0,0 +1,22 @@ +#!/bin/bash +#SBATCH --job-name=cogvideox-lora-inference +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out + +set -xe + +module load cuda/12.1 + +echo "START TIME: $(date)" + +echo "Starting job" + +srun dump_inference_lora_t2v.sh + +echo "END TIME: $(date)" diff --git a/slurm/prepare_dataset.slurm b/slurm/prepare_dataset.slurm new file mode 100644 index 00000000..fba162aa --- /dev/null +++ b/slurm/prepare_dataset.slurm @@ -0,0 +1,52 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +MODEL_ID="THUDM/CogVideoX-5b" +NUM_GPUS=8 + +DATA_ROOT="/path/to/datasets/dataset-cakify-yt" +CAPTION_COLUMN="prompts.txt" +VIDEO_COLUMN="videos.txt" +OUTPUT_DIR="/path/to/datasets/dataset-cakify-yt-encoded" +HEIGHT_BUCKETS="480 720 768" +WIDTH_BUCKETS="480 720 768" +FRAME_BUCKETS="17 33 49" +MAX_NUM_FRAMES=49 +MAX_SEQUENCE_LENGTH=226 +TARGET_FPS=8 +BATCH_SIZE=1 +DTYPE=fp32 +ADDITIONAL_FLAGS="--save_image_latents --save_latents_and_embeddings" + +set -xe + +module load cuda/12.1 + +echo "Starting job" + +srun torchrun --nproc_per_node=$NUM_GPUS training/prepare_dataset.py \ + --model_id $MODEL_ID \ + --data_root $DATA_ROOT \ + --caption_column $CAPTION_COLUMN \ + --video_column $VIDEO_COLUMN \ + --output_dir $OUTPUT_DIR \ + --height_buckets $HEIGHT_BUCKETS \ + --width_buckets $WIDTH_BUCKETS \ + --frame_buckets $FRAME_BUCKETS \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_sequence_length $MAX_SEQUENCE_LENGTH \ + --target_fps $TARGET_FPS \ + --batch_size $BATCH_SIZE \ + --dtype $DTYPE \ + $ADDITIONAL_FLAGS + +echo "End time: $(date)" diff --git a/slurm/train_lora_t2v.slurm b/slurm/train_lora_t2v.slurm new file mode 100644 index 00000000..5d1d9209 --- /dev/null +++ b/slurm/train_lora_t2v.slurm @@ -0,0 +1,28 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=168:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +set -xe + +echo "START TIME: $(date)" + +# Show some environment variables +echo python3 version = `python3 --version` +echo "NCCL version: $(python -c "import torch;print(torch.cuda.nccl.version())")" +echo "CUDA version: $(python -c "import torch;print(torch.version.cuda)")" + +module load cuda/12.1 + +echo "Starting job" + +srun dump_train_text_to_video_lora.sh + +echo "END TIME: $(date)" diff --git a/slurm/vllm_caption.slurm b/slurm/vllm_caption.slurm new file mode 100755 index 00000000..6826400c --- /dev/null +++ b/slurm/vllm_caption.slurm @@ -0,0 +1,42 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out + +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +INPUT_FILE="/path/to/datasets/dataset-cakify-yt/summary.csv" +OUTPUT_DIR="/path/to/datasets/dataset-cakify-yt/" +NUM_DEVICES=8 +BATCH_SIZE=4 +MAX_PROMPT_GEN_TOKENS=256 +NUM_DATA_WORKERS=8 +NUM_ARTIFACT_WORKERS=8 +ADDITIONAL_FLAGS="--trust_remote_code" + +PROMPT_GEN_PROMPT="Could you generate a prompt for a text-to-video generation model given the following summary:\n\n\`\`\`\n{0}\n\`\`\`\n\nPlease limit the prompt to [{1}] words. To provide some additional context, these summaries are generated for videos that depict cutting of tasty cakes disguised as a realistic looking object." + +set -xe + +module load cuda/12.1 + +echo "Starting job" + +srun python3 captioning/vllm_caption.py \ + --input_file $INPUT_FILE \ + --output_dir $OUTPUT_DIR \ + --num_devices $NUM_DEVICES \ + --max_prompt_gen_tokens $MAX_PROMPT_GEN_TOKENS \ + --prompt_gen_prompt "$PROMPT_GEN_PROMPT" \ + --batch_size $BATCH_SIZE \ + --num_data_workers $NUM_DATA_WORKERS \ + --num_artifact_workers $NUM_ARTIFACT_WORKERS \ + $ADDITIONAL_FLAGS + +echo "End time: $(date)" diff --git a/slurm/vllm_summary.slurm b/slurm/vllm_summary.slurm new file mode 100755 index 00000000..2260bb03 --- /dev/null +++ b/slurm/vllm_summary.slurm @@ -0,0 +1,45 @@ +#!/bin/bash +#SBATCH --job-name=recaptioning +#SBATCH --nodes=1 +#SBATCH --qos=normal +#SBATCH --time=12:00:00 +#SBATCH --requeue +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --partition=hopper-prod +#SBATCH -o /path/to/logs/logs-%x-%j.out +#SBATCH --exclusive + +export VLLM_WORKER_MULTIPROC_METHOD=spawn + +ROOT_DIR="/path/to/datasets/dataset-cakify-yt/videos" +OUTPUT_DIR="/path/to/datasets/dataset-cakify-yt/" +NUM_DEVICES=4 +BATCH_SIZE=4 +MAX_NUM_FRAMES=16 +MAX_SUMMARY_TOKENS=2048 +NUM_DATA_WORKERS=8 +NUM_ARTIFACT_WORKERS=8 +ADDITIONAL_FLAGS="--trust_remote_code" + +VIDEO_SUMMARY_PROMPT="Summarize this video and limit to 500 words. The sequence of video frames are of a person cutting cakes shaped as different realistic objects. Make sure to mention about the cake cutting in the summary." + +set -xe + +module load cuda/12.1 + +echo "Starting job" + +srun python3 captioning/vllm_summary.py \ + --root_dir $ROOT_DIR \ + --output_dir $OUTPUT_DIR \ + --num_devices $NUM_DEVICES \ + --max_num_frames $MAX_NUM_FRAMES \ + --max_summary_tokens $MAX_SUMMARY_TOKENS \ + --video_summary_prompt "$VIDEO_SUMMARY_PROMPT" \ + --batch_size $BATCH_SIZE \ + --num_data_workers $NUM_DATA_WORKERS \ + --num_artifact_workers $NUM_ARTIFACT_WORKERS \ + $ADDITIONAL_FLAGS + +echo "End time: $(date)" diff --git a/tests/test_lora_inference.py b/tests/test_lora_inference.py index 31428a4b..d1c0439b 100644 --- a/tests/test_lora_inference.py +++ b/tests/test_lora_inference.py @@ -8,6 +8,7 @@ """ import argparse + import torch from diffusers import CogVideoXPipeline from diffusers.utils import export_to_video diff --git a/training/cogvideox_text_to_video_lora.py b/training/cogvideox_text_to_video_lora.py index cf678a5a..2e071440 100644 --- a/training/cogvideox_text_to_video_lora.py +++ b/training/cogvideox_text_to_video_lora.py @@ -62,7 +62,7 @@ print_memory, reset_memory, unwrap_model, -) # isort:skip +) logger = get_logger(__name__) diff --git a/training/cogvideox_text_to_video_sft.py b/training/cogvideox_text_to_video_sft.py index 9b442175..01563019 100644 --- a/training/cogvideox_text_to_video_sft.py +++ b/training/cogvideox_text_to_video_sft.py @@ -61,7 +61,7 @@ print_memory, reset_memory, unwrap_model, -) # isort:skip +) logger = get_logger(__name__) diff --git a/training/dataset.py b/training/dataset.py index dca375df..fa38de85 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -277,7 +277,9 @@ def _preprocess_video(self, path: Path) -> torch.Tensor: video_reader = decord.VideoReader(uri=path.as_posix()) video_num_frames = len(video_reader) nearest_frame_bucket = min( - self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)) + [bucket for bucket in self.frame_buckets if bucket <= video_num_frames], + key=lambda x: abs(x - min(video_num_frames, self.max_num_frames)), + default=1, ) frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))