-
Couldn't load subscription status.
- Fork 189
[Preprocess] [feat] Support HunyuanVideo Model #754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
93e8e2b
766b931
d30879a
3b39366
94b8767
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| #!/bin/bash | ||
|
|
||
| GPU_NUM=1 # 2,4,8 | ||
| MODEL_PATH="hunyuanvideo-community/HunyuanVideo" | ||
| DATASET_PATH="data/crush-smol" | ||
| OUTPUT_DIR="data/crush-smol_processed_t2v_hunyuan/" | ||
|
|
||
| torchrun --nproc_per_node=$GPU_NUM \ | ||
| -m fastvideo.pipelines.preprocess.v1_preprocessing_new \ | ||
| --model_path $MODEL_PATH \ | ||
| --mode preprocess \ | ||
| --workload_type t2v \ | ||
| --preprocess.dataset_type merged \ | ||
| --preprocess.dataset_path $DATASET_PATH \ | ||
| --preprocess.dataset_output_dir $OUTPUT_DIR \ | ||
| --preprocess.preprocess_video_batch_size 2 \ | ||
| --preprocess.dataloader_num_workers 0 \ | ||
| --preprocess.max_height 480 \ | ||
| --preprocess.max_width 832 \ | ||
| --preprocess.num_frames 77 \ | ||
| --preprocess.train_fps 16 \ | ||
| --preprocess.samples_per_file 8 \ | ||
| --preprocess.flush_frequency 8 \ | ||
| --preprocess.video_length_tolerance_range 5 | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| from fastvideo.fastvideo_args import FastVideoArgs | ||
| from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase | ||
| from fastvideo.pipelines.preprocess.preprocess_stages import ( | ||
| TextTransformStage, VideoTransformStage) | ||
| from fastvideo.pipelines.stages import (EncodingStage, ImageEncodingStage, | ||
| TextEncodingStage) | ||
| from fastvideo.pipelines.stages.image_encoding import ImageVAEEncodingStage | ||
|
|
||
|
|
||
| class PreprocessPipelineI2V(ComposedPipelineBase): | ||
| _required_config_modules = [ | ||
| "image_encoder", "image_processor", "text_encoder", "tokenizer", | ||
| "text_encoder_2", "tokenizer_2", "vae" | ||
| ] | ||
|
|
||
| def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): | ||
| assert fastvideo_args.preprocess_config is not None | ||
| self.add_stage(stage_name="text_transform_stage", | ||
| stage=TextTransformStage( | ||
| cfg_uncondition_drop_rate=fastvideo_args. | ||
| preprocess_config.training_cfg_rate, | ||
| seed=fastvideo_args.preprocess_config.seed, | ||
| )) | ||
| self.add_stage(stage_name="prompt_encoding_stage", | ||
| stage=TextEncodingStage( | ||
| text_encoders=[self.get_module("text_encoder")], | ||
| tokenizers=[self.get_module("tokenizer")], | ||
| )) | ||
| self.add_stage( | ||
| stage_name="video_transform_stage", | ||
| stage=VideoTransformStage( | ||
| train_fps=fastvideo_args.preprocess_config.train_fps, | ||
| num_frames=fastvideo_args.preprocess_config.num_frames, | ||
| max_height=fastvideo_args.preprocess_config.max_height, | ||
| max_width=fastvideo_args.preprocess_config.max_width, | ||
| do_temporal_sample=fastvideo_args.preprocess_config. | ||
| do_temporal_sample, | ||
| )) | ||
| if (self.get_module("image_encoder") is not None | ||
| and self.get_module("image_processor") is not None): | ||
| self.add_stage( | ||
| stage_name="image_encoding_stage", | ||
| stage=ImageEncodingStage( | ||
| image_encoder=self.get_module("image_encoder"), | ||
| image_processor=self.get_module("image_processor"), | ||
| )) | ||
| self.add_stage(stage_name="image_vae_encoding_stage", | ||
| stage=ImageVAEEncodingStage( | ||
| vae=self.get_module("vae"), )) | ||
| self.add_stage(stage_name="video_encoding_stage", | ||
| stage=EncodingStage(vae=self.get_module("vae"), )) | ||
|
|
||
|
|
||
| class PreprocessPipelineT2V(ComposedPipelineBase): | ||
| _required_config_modules = [ | ||
| "text_encoder", "tokenizer", "text_encoder_2", "tokenizer_2", "vae" | ||
| ] | ||
|
|
||
| def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): | ||
| assert fastvideo_args.preprocess_config is not None | ||
| self.add_stage(stage_name="text_transform_stage", | ||
| stage=TextTransformStage( | ||
| cfg_uncondition_drop_rate=fastvideo_args. | ||
| preprocess_config.training_cfg_rate, | ||
| seed=fastvideo_args.preprocess_config.seed, | ||
| )) | ||
| # llama_tokenizer_kwargs = { | ||
| # "padding": "max_length", | ||
| # "truncation": True, | ||
| # "max_length": 256, | ||
| # "return_tensors": "pt" | ||
| # } | ||
| # clip_tokenizer_kwargs = { | ||
| # "padding": "max_length", | ||
| # "truncation": True, | ||
| # "max_length": 77, | ||
| # "return_tensors": "pt" | ||
| # } | ||
| # if len(fastvideo_args.pipeline_config.text_encoder_configs) >= 2: | ||
| # fastvideo_args.pipeline_config.text_encoder_configs[0].tokenizer_kwargs = llama_tokenizer_kwargs | ||
| # fastvideo_args.pipeline_config.text_encoder_configs[1].tokenizer_kwargs = clip_tokenizer_kwargs | ||
| text_encoders = [ | ||
| self.get_module("text_encoder"), | ||
| self.get_module("text_encoder_2") | ||
| ] | ||
| tokenizers = [ | ||
| self.get_module("tokenizer"), | ||
| self.get_module("tokenizer_2") | ||
| ] | ||
|
|
||
| self.add_stage(stage_name="prompt_encoding_stage", | ||
| stage=TextEncodingStage( | ||
| text_encoders=text_encoders, | ||
| tokenizers=tokenizers, | ||
| )) | ||
| self.add_stage( | ||
| stage_name="video_transform_stage", | ||
| stage=VideoTransformStage( | ||
| train_fps=fastvideo_args.preprocess_config.train_fps, | ||
| num_frames=fastvideo_args.preprocess_config.num_frames, | ||
| max_height=fastvideo_args.preprocess_config.max_height, | ||
| max_width=fastvideo_args.preprocess_config.max_width, | ||
| do_temporal_sample=fastvideo_args.preprocess_config. | ||
| do_temporal_sample, | ||
| )) | ||
| self.add_stage(stage_name="video_encoding_stage", | ||
| stage=EncodingStage(vae=self.get_module("vae"), )) | ||
|
|
||
|
|
||
| EntryClass = [PreprocessPipelineI2V, PreprocessPipelineT2V] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,7 +14,6 @@ | |
| from datasets import Dataset, Video, load_dataset | ||
|
|
||
| from fastvideo.configs.configs import DatasetType, PreprocessConfig | ||
| from fastvideo.distributed.parallel_state import get_world_rank, get_world_size | ||
| from fastvideo.logger import init_logger | ||
| from fastvideo.pipelines.pipeline_batch_info import PreprocessBatch | ||
|
|
||
|
|
@@ -79,8 +78,8 @@ def __call__(self, batch: dict[str, Any]) -> bool: | |
|
|
||
| def _validate_data_type(self, batch: dict[str, Any]) -> bool: | ||
| """Validate basic validity of data items""" | ||
| return not (batch["caption"] is None or batch["caption"] == "" | ||
| or batch["fps"] is None or batch["fps"] <= 0 | ||
| return not (batch["caption"] is None or batch["caption"] == "" or "fps" | ||
| not in batch or batch["fps"] is None or batch["fps"] <= 0 | ||
| or batch["num_frames"] is None or batch["num_frames"] <= 0) | ||
|
|
||
| def _validate_resolution(self, batch: dict[str, Any]) -> bool: | ||
|
|
@@ -400,13 +399,9 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int: | |
| return written_count | ||
|
|
||
|
|
||
| def build_dataset(preprocess_config: PreprocessConfig, split: str, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert this |
||
| validator: Callable[[dict[str, Any]], bool]) -> Dataset: | ||
| def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset: | ||
| if preprocess_config.dataset_type == DatasetType.HF: | ||
| dataset = load_dataset(preprocess_config.dataset_path, split=split) | ||
| dataset = dataset.filter(validator) | ||
| dataset = dataset.shard(num_shards=get_world_size(), | ||
| index=get_world_rank()) | ||
| elif preprocess_config.dataset_type == DatasetType.MERGED: | ||
| metadata_json_path = os.path.join(preprocess_config.dataset_path, | ||
| "videos2caption.json") | ||
|
|
@@ -420,11 +415,6 @@ def build_dataset(preprocess_config: PreprocessConfig, split: str, | |
| dataset = dataset.rename_column("cap", "caption") | ||
| if "path" in column_names: | ||
| dataset = dataset.rename_column("path", "name") | ||
|
|
||
| dataset = dataset.filter(validator) | ||
| dataset = dataset.shard(num_shards=get_world_size(), | ||
| index=get_world_rank()) | ||
|
|
||
| # add video column | ||
| def add_video_column(item: dict[str, Any]) -> dict[str, Any]: | ||
| item["video"] = os.path.join(video_folder, item["name"]) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,9 +44,9 @@ def register_components(self) -> None: | |
| self.add_component("raw_data_validator", raw_data_validator) | ||
|
|
||
| # training dataset | ||
| training_dataset = build_dataset(preprocess_config, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revert this |
||
| split="train", | ||
| validator=raw_data_validator) | ||
| training_dataset = build_dataset(preprocess_config, split="train") | ||
| # set load_from_cache_file to False to check filter stats | ||
| training_dataset = training_dataset.filter(raw_data_validator) | ||
| # we do not use collate_fn here because we use iterable-style Dataset | ||
| # and want to keep the original type of the dataset | ||
| training_dataloader = DataLoader( | ||
|
|
@@ -60,8 +60,8 @@ def register_components(self) -> None: | |
| # try to load validation dataset if it exists | ||
| try: | ||
| validation_dataset = build_dataset(preprocess_config, | ||
| split="validation", | ||
| validator=raw_data_validator) | ||
| split="validation") | ||
| validation_dataset = validation_dataset.filter(raw_data_validator) | ||
| validation_dataloader = DataLoader( | ||
| validation_dataset, | ||
| batch_size=preprocess_config.preprocess_video_batch_size, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,8 +23,10 @@ def get_video_info(video_path): | |
| fps = info.get("video_fps", 0) | ||
| duration = num_frames / fps if fps > 0 else 0 | ||
|
|
||
| # Extract name | ||
| _, _, videos_dir, video_name = str(video_path).split("/") | ||
| from pathlib import Path | ||
| video_path = Path(video_path) | ||
| videos_dir = video_path.parent.name | ||
| video_name = video_path.name | ||
|
|
||
| return { | ||
| "path": str(video_name), | ||
|
|
@@ -100,6 +102,7 @@ def prepare_dataset_json(folder_path, | |
|
|
||
| # Save to JSON file | ||
| output_file = folder_path / output_name | ||
| print(folder_path,output_file,output_name) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove this |
||
| with open(output_file, 'w') as f: | ||
| json.dump(dataset_info, f, indent=2) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some issues with your rebase. You should align your code with our implementation(accept both) instead of only use your code(accept current). Please revert this.