diff --git a/examples/training/finetune/hunyuan_t2v/preprocess_hunyuan_data_t2v.sh b/examples/training/finetune/hunyuan_t2v/preprocess_hunyuan_data_t2v.sh
new file mode 100755
index 000000000..d0d9487ae
--- /dev/null
+++ b/examples/training/finetune/hunyuan_t2v/preprocess_hunyuan_data_t2v.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+GPU_NUM=1 # 2,4,8
+MODEL_PATH="hunyuanvideo-community/HunyuanVideo"
+DATASET_PATH="data/crush-smol"
+OUTPUT_DIR="data/crush-smol_processed_t2v_hunyuan/"
+
+torchrun --nproc_per_node=$GPU_NUM \
+    -m fastvideo.pipelines.preprocess.v1_preprocessing_new \
+    --model_path $MODEL_PATH \
+    --mode preprocess \
+    --workload_type t2v \
+    --preprocess.dataset_type merged \
+    --preprocess.dataset_path $DATASET_PATH \
+    --preprocess.dataset_output_dir $OUTPUT_DIR \
+    --preprocess.preprocess_video_batch_size 2 \
+    --preprocess.dataloader_num_workers 0 \
+    --preprocess.max_height 480 \
+    --preprocess.max_width 832 \
+    --preprocess.num_frames 77 \
+    --preprocess.train_fps 16 \
+    --preprocess.samples_per_file 8 \
+    --preprocess.flush_frequency 8 \
+    --preprocess.video_length_tolerance_range 5
+
diff --git a/fastvideo/configs/models/encoders/clip.py b/fastvideo/configs/models/encoders/clip.py
index a7d313a86..b8b4942e1 100644
--- a/fastvideo/configs/models/encoders/clip.py
+++ b/fastvideo/configs/models/encoders/clip.py
@@ -74,7 +74,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
 class CLIPTextConfig(TextEncoderConfig):
     arch_config: TextEncoderArchConfig = field(
         default_factory=CLIPTextArchConfig)
-
+    tokenizer_kwargs: dict = field(
+        default_factory=lambda: {
+            "padding": "max_length",
+            "truncation": True,
+            "max_length": 77,
+            "return_tensors": "pt"
+        })
     num_hidden_layers_override: int | None = None
     require_post_norm: bool | None = None
     prefix: str = "clip"
diff --git a/fastvideo/configs/models/encoders/llama.py b/fastvideo/configs/models/encoders/llama.py
index 53fc21e74..988a83f94 100644
--- a/fastvideo/configs/models/encoders/llama.py
+++ b/fastvideo/configs/models/encoders/llama.py
@@ -60,5 +60,11 @@ class LlamaArchConfig(TextEncoderArchConfig):
 @dataclass
 class LlamaConfig(TextEncoderConfig):
     arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig)
-
+    tokenizer_kwargs: dict = field(
+        default_factory=lambda: {
+            "padding": "max_length",
+            "truncation": True,
+            "max_length": 256,
+            "return_tensors": "pt"
+        })
     prefix: str = "llama"
diff --git a/fastvideo/layers/rotary_embedding.py b/fastvideo/layers/rotary_embedding.py
index 6abe90609..24c9d7f62 100644
--- a/fastvideo/layers/rotary_embedding.py
+++ b/fastvideo/layers/rotary_embedding.py
@@ -138,14 +138,14 @@ def forward_native(
         cos, sin = cos_sin.chunk(2, dim=-1)
 
         query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
+        query = query.reshape(num_tokens, -1, self.head_size)
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
         query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
 
         key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
+        key = key.reshape(num_tokens, -1, self.head_size)
         key_rot = key[..., :self.rotary_dim]
         key_pass = key[..., self.rotary_dim:]
         key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
diff --git a/fastvideo/models/vaes/hunyuanvae.py b/fastvideo/models/vaes/hunyuanvae.py
index d0f614ea3..a2eb8d48b 100644
--- a/fastvideo/models/vaes/hunyuanvae.py
+++ b/fastvideo/models/vaes/hunyuanvae.py
@@ -361,7 +361,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                     hidden_states.device,
                     batch_size=batch_size)
 
             hidden_states = attn(hidden_states,
-                                 attention_mask=attention_mask)
+                                 attention_mask=attention_mask.unsqueeze(1))
             hidden_states = hidden_states.unflatten(
                 1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
@@ -385,7 +385,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                     hidden_states.device,
                     batch_size=batch_size)
 
             hidden_states = attn(hidden_states,
-                                 attention_mask=attention_mask)
+                                 attention_mask=attention_mask.unsqueeze(1))
             hidden_states = hidden_states.unflatten(
                 1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
diff --git a/fastvideo/pipelines/preprocess/hunyuan/__init__.py b/fastvideo/pipelines/preprocess/hunyuan/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fastvideo/pipelines/preprocess/hunyuan/hunyuan_preprocess_pipelines.py b/fastvideo/pipelines/preprocess/hunyuan/hunyuan_preprocess_pipelines.py
new file mode 100644
index 000000000..7245be8c7
--- /dev/null
+++ b/fastvideo/pipelines/preprocess/hunyuan/hunyuan_preprocess_pipelines.py
@@ -0,0 +1,97 @@
+from fastvideo.fastvideo_args import FastVideoArgs
+from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase
+from fastvideo.pipelines.preprocess.preprocess_stages import (
+    TextTransformStage, VideoTransformStage)
+from fastvideo.pipelines.stages import (EncodingStage, ImageEncodingStage,
+                                        TextEncodingStage)
+from fastvideo.pipelines.stages.image_encoding import ImageVAEEncodingStage
+
+
+class PreprocessPipelineI2V(ComposedPipelineBase):
+    _required_config_modules = [
+        "image_encoder", "image_processor", "text_encoder", "tokenizer",
+        "text_encoder_2", "tokenizer_2", "vae"
+    ]
+
+    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
+        assert fastvideo_args.preprocess_config is not None
+        self.add_stage(stage_name="text_transform_stage",
+                       stage=TextTransformStage(
+                           cfg_uncondition_drop_rate=fastvideo_args.
+                           preprocess_config.training_cfg_rate,
+                           seed=fastvideo_args.preprocess_config.seed,
+                       ))
+        self.add_stage(stage_name="prompt_encoding_stage",
+                       stage=TextEncodingStage(
+                           text_encoders=[self.get_module("text_encoder")],
+                           tokenizers=[self.get_module("tokenizer")],
+                       ))
+        self.add_stage(
+            stage_name="video_transform_stage",
+            stage=VideoTransformStage(
+                train_fps=fastvideo_args.preprocess_config.train_fps,
+                num_frames=fastvideo_args.preprocess_config.num_frames,
+                max_height=fastvideo_args.preprocess_config.max_height,
+                max_width=fastvideo_args.preprocess_config.max_width,
+                do_temporal_sample=fastvideo_args.preprocess_config.
+                do_temporal_sample,
+            ))
+        if (self.get_module("image_encoder") is not None
+                and self.get_module("image_processor") is not None):
+            self.add_stage(
+                stage_name="image_encoding_stage",
+                stage=ImageEncodingStage(
+                    image_encoder=self.get_module("image_encoder"),
+                    image_processor=self.get_module("image_processor"),
+                ))
+        self.add_stage(stage_name="image_vae_encoding_stage",
+                       stage=ImageVAEEncodingStage(
+                           vae=self.get_module("vae"), ))
+        self.add_stage(stage_name="video_encoding_stage",
+                       stage=EncodingStage(vae=self.get_module("vae"), ))
+
+
+class PreprocessPipelineT2V(ComposedPipelineBase):
+    _required_config_modules = [
+        "text_encoder", "tokenizer", "text_encoder_2", "tokenizer_2", "vae"
+    ]
+
+    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
+        assert fastvideo_args.preprocess_config is not None
+        self.add_stage(stage_name="text_transform_stage",
+                       stage=TextTransformStage(
+                           cfg_uncondition_drop_rate=fastvideo_args.
+                           preprocess_config.training_cfg_rate,
+                           seed=fastvideo_args.preprocess_config.seed,
+                       ))
+        # Tokenizer kwargs (padding/truncation/max_length) for both text
+        # encoders now live on their encoder configs, so no override here.
+        text_encoders = [
+            self.get_module("text_encoder"),
+            self.get_module("text_encoder_2")
+        ]
+        tokenizers = [
+            self.get_module("tokenizer"),
+            self.get_module("tokenizer_2")
+        ]
+
+        self.add_stage(stage_name="prompt_encoding_stage",
+                       stage=TextEncodingStage(
+                           text_encoders=text_encoders,
+                           tokenizers=tokenizers,
+                       ))
+        self.add_stage(
+            stage_name="video_transform_stage",
+            stage=VideoTransformStage(
+                train_fps=fastvideo_args.preprocess_config.train_fps,
+                num_frames=fastvideo_args.preprocess_config.num_frames,
+                max_height=fastvideo_args.preprocess_config.max_height,
+                max_width=fastvideo_args.preprocess_config.max_width,
+                do_temporal_sample=fastvideo_args.preprocess_config.
+                do_temporal_sample,
+            ))
+        self.add_stage(stage_name="video_encoding_stage",
+                       stage=EncodingStage(vae=self.get_module("vae"), ))
+
+
+EntryClass = [PreprocessPipelineI2V, PreprocessPipelineT2V]
diff --git a/fastvideo/workflow/preprocess/components.py b/fastvideo/workflow/preprocess/components.py
index aa9bdbbcd..dd86ba7d3 100644
--- a/fastvideo/workflow/preprocess/components.py
+++ b/fastvideo/workflow/preprocess/components.py
@@ -14,7 +14,6 @@
 from datasets import Dataset, Video, load_dataset
 
 from fastvideo.configs.configs import DatasetType, PreprocessConfig
-from fastvideo.distributed.parallel_state import get_world_rank, get_world_size
 from fastvideo.logger import init_logger
 from fastvideo.pipelines.pipeline_batch_info import PreprocessBatch
 
@@ -79,8 +78,8 @@ def __call__(self, batch: dict[str, Any]) -> bool:
 
     def _validate_data_type(self, batch: dict[str, Any]) -> bool:
         """Validate basic validity of data items"""
-        return not (batch["caption"] is None or batch["caption"] == ""
-                    or batch["fps"] is None or batch["fps"] <= 0
+        return not (batch["caption"] is None or batch["caption"] == "" or "fps"
+                    not in batch or batch["fps"] is None or batch["fps"] <= 0
                     or batch["num_frames"] is None or batch["num_frames"] <= 0)
 
     def _validate_resolution(self, batch: dict[str, Any]) -> bool:
@@ -400,13 +399,9 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int:
         return written_count
 
 
-def build_dataset(preprocess_config: PreprocessConfig, split: str,
-                  validator: Callable[[dict[str, Any]], bool]) -> Dataset:
+def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
     if preprocess_config.dataset_type == DatasetType.HF:
         dataset = load_dataset(preprocess_config.dataset_path, split=split)
-        dataset = dataset.filter(validator)
-        dataset = dataset.shard(num_shards=get_world_size(),
-                                index=get_world_rank())
     elif preprocess_config.dataset_type == DatasetType.MERGED:
         metadata_json_path = os.path.join(preprocess_config.dataset_path,
                                           "videos2caption.json")
@@ -420,11 +415,6 @@ def build_dataset(preprocess_config: PreprocessConfig, split: str,
         dataset = dataset.rename_column("cap", "caption")
     if "path" in column_names:
         dataset = dataset.rename_column("path", "name")
-
-        dataset = dataset.filter(validator)
-        dataset = dataset.shard(num_shards=get_world_size(),
-                                index=get_world_rank())
-
     # add video column
     def add_video_column(item: dict[str, Any]) -> dict[str, Any]:
         item["video"] = os.path.join(video_folder, item["name"])
diff --git a/fastvideo/workflow/preprocess/preprocess_workflow.py b/fastvideo/workflow/preprocess/preprocess_workflow.py
index 8f83f07db..0146098ae 100644
--- a/fastvideo/workflow/preprocess/preprocess_workflow.py
+++ b/fastvideo/workflow/preprocess/preprocess_workflow.py
@@ -44,9 +44,9 @@ def register_components(self) -> None:
         self.add_component("raw_data_validator", raw_data_validator)
 
         # training dataset
-        training_dataset = build_dataset(preprocess_config,
-                                         split="train",
-                                         validator=raw_data_validator)
+        training_dataset = build_dataset(preprocess_config, split="train")
+        # note: pass load_from_cache_file=False to filter() to recompute
+        training_dataset = training_dataset.filter(raw_data_validator)
         # we do not use collate_fn here because we use iterable-style Dataset
         # and want to keep the original type of the dataset
         training_dataloader = DataLoader(
@@ -60,8 +60,8 @@ def register_components(self) -> None:
         # try to load validation dataset if it exists
         try:
             validation_dataset = build_dataset(preprocess_config,
-                                               split="validation",
-                                               validator=raw_data_validator)
+ split="validation") + validation_dataset = validation_dataset.filter(raw_data_validator) validation_dataloader = DataLoader( validation_dataset, batch_size=preprocess_config.preprocess_video_batch_size, diff --git a/scripts/dataset_preparation/prepare_json_file.py b/scripts/dataset_preparation/prepare_json_file.py index b263b7053..7b67d7951 100644 --- a/scripts/dataset_preparation/prepare_json_file.py +++ b/scripts/dataset_preparation/prepare_json_file.py @@ -23,8 +23,10 @@ def get_video_info(video_path): fps = info.get("video_fps", 0) duration = num_frames / fps if fps > 0 else 0 - # Extract name - _, _, videos_dir, video_name = str(video_path).split("/") + from pathlib import Path + video_path = Path(video_path) + videos_dir = video_path.parent.name + video_name = video_path.name return { "path": str(video_name), @@ -100,6 +102,7 @@ def prepare_dataset_json(folder_path, # Save to JSON file output_file = folder_path / output_name + print(folder_path,output_file,output_name) with open(output_file, 'w') as f: json.dump(dataset_info, f, indent=2)