@@ -0,0 +1,25 @@
#!/bin/bash

GPU_NUM=1 # number of GPUs to use: 1, 2, 4, or 8
MODEL_PATH="hunyuanvideo-community/HunyuanVideo"
DATASET_PATH="data/crush-smol"
OUTPUT_DIR="data/crush-smol_processed_t2v_hunyuan/"

torchrun --nproc_per_node=$GPU_NUM \
-m fastvideo.pipelines.preprocess.v1_preprocessing_new \
--model_path $MODEL_PATH \
--mode preprocess \
--workload_type t2v \
--preprocess.dataset_type merged \
--preprocess.dataset_path $DATASET_PATH \
--preprocess.dataset_output_dir $OUTPUT_DIR \
--preprocess.preprocess_video_batch_size 2 \
--preprocess.dataloader_num_workers 0 \
--preprocess.max_height 480 \
--preprocess.max_width 832 \
--preprocess.num_frames 77 \
--preprocess.train_fps 16 \
--preprocess.samples_per_file 8 \
--preprocess.flush_frequency 8 \
--preprocess.video_length_tolerance_range 5
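For reference, the merged-dataset loader in fastvideo/workflow/preprocess/components.py (further down in this diff) reads a videos2caption.json from the dataset path and joins each entry's name onto a video folder. A sketch of the assumed layout of data/crush-smol (the folder name "videos" is an assumption; the JSON field names come from the loader's rename logic):

data/crush-smol/
├── videos2caption.json   # entries with "cap"/"caption" and "path"/"name" fields
└── videos/
    ├── clip_0001.mp4
    └── ...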

8 changes: 7 additions & 1 deletion fastvideo/configs/models/encoders/clip.py
@@ -74,7 +74,13 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig):
class CLIPTextConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(
default_factory=CLIPTextArchConfig)

tokenizer_kwargs: dict = field(
default_factory=lambda: {
"padding": "max_length",
"truncation": True,
"max_length": 77,
"return_tensors": "pt"
})
num_hidden_layers_override: int | None = None
require_post_norm: bool | None = None
prefix: str = "clip"
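These kwargs follow the standard Hugging Face tokenizer call signature, with max_length=77 matching CLIP's context window. A minimal sketch of how such kwargs are applied (the checkpoint name is an assumption for illustration):

from transformers import AutoTokenizer

# Hypothetical checkpoint; any CLIP-style tokenizer accepts these kwargs.
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
kwargs = {"padding": "max_length", "truncation": True,
          "max_length": 77, "return_tensors": "pt"}
out = tokenizer(["a photo of a cat"], **kwargs)
print(out["input_ids"].shape)  # torch.Size([1, 77])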
8 changes: 7 additions & 1 deletion fastvideo/configs/models/encoders/llama.py
@@ -60,5 +60,11 @@ class LlamaArchConfig(TextEncoderArchConfig):
@dataclass
class LlamaConfig(TextEncoderConfig):
arch_config: TextEncoderArchConfig = field(default_factory=LlamaArchConfig)

tokenizer_kwargs: dict = field(
default_factory=lambda: {
"padding": "max_length",
"truncation": True,
"max_length": 256,
"return_tensors": "pt"
})
prefix: str = "llama"
4 changes: 2 additions & 2 deletions fastvideo/layers/rotary_embedding.py
@@ -138,14 +138,14 @@ def forward_native(
cos, sin = cos_sin.chunk(2, dim=-1)

query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query = query.reshape(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key = key.reshape(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
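The switch from .view to .reshape matters because .view requires contiguous memory, while .reshape transparently copies when needed. A minimal sketch of the failure mode this avoids (shapes are illustrative):

import torch

x = torch.randn(4, 8, 16).transpose(0, 1)  # transpose makes x non-contiguous
try:
    x.view(8, -1)  # raises RuntimeError: view needs contiguous memory
except RuntimeError as e:
    print("view failed:", e)
y = x.reshape(8, -1)  # reshape falls back to a copy when needed
print(y.shape)  # torch.Size([8, 64])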
4 changes: 2 additions & 2 deletions fastvideo/models/vaes/hunyuanvae.py
@@ -361,7 +361,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states.device,
batch_size=batch_size)
hidden_states = attn(hidden_states,
attention_mask=attention_mask)
attention_mask=attention_mask.unsqueeze(1))
hidden_states = hidden_states.unflatten(
1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

@@ -385,7 +385,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states.device,
batch_size=batch_size)
hidden_states = attn(hidden_states,
attention_mask=attention_mask)
attention_mask=attention_mask.unsqueeze(1))
hidden_states = hidden_states.unflatten(
1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)

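Here unsqueeze(1) inserts a singleton head dimension so the mask broadcasts across attention heads. A small shape sketch (dimensions are illustrative):

import torch

batch, seq = 2, 16
attention_mask = torch.ones(batch, seq, seq)  # (B, L, L)
mask = attention_mask.unsqueeze(1)            # (B, 1, L, L), broadcasts over heads
print(mask.shape)  # torch.Size([2, 1, 16, 16])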
Empty file.
110 changes: 110 additions & 0 deletions fastvideo/pipelines/preprocess/hunyuan/hunyuan_preprocess_pipelines.py
@@ -0,0 +1,110 @@
from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase
from fastvideo.pipelines.preprocess.preprocess_stages import (
TextTransformStage, VideoTransformStage)
from fastvideo.pipelines.stages import (EncodingStage, ImageEncodingStage,
TextEncodingStage)
from fastvideo.pipelines.stages.image_encoding import ImageVAEEncodingStage


class PreprocessPipelineI2V(ComposedPipelineBase):
_required_config_modules = [
"image_encoder", "image_processor", "text_encoder", "tokenizer",
"text_encoder_2", "tokenizer_2", "vae"
]

def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
assert fastvideo_args.preprocess_config is not None
self.add_stage(stage_name="text_transform_stage",
stage=TextTransformStage(
cfg_uncondition_drop_rate=fastvideo_args.
preprocess_config.training_cfg_rate,
seed=fastvideo_args.preprocess_config.seed,
))
self.add_stage(stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=[self.get_module("text_encoder")],
tokenizers=[self.get_module("tokenizer")],
))
self.add_stage(
stage_name="video_transform_stage",
stage=VideoTransformStage(
train_fps=fastvideo_args.preprocess_config.train_fps,
num_frames=fastvideo_args.preprocess_config.num_frames,
max_height=fastvideo_args.preprocess_config.max_height,
max_width=fastvideo_args.preprocess_config.max_width,
do_temporal_sample=fastvideo_args.preprocess_config.
do_temporal_sample,
))
if (self.get_module("image_encoder") is not None
and self.get_module("image_processor") is not None):
self.add_stage(
stage_name="image_encoding_stage",
stage=ImageEncodingStage(
image_encoder=self.get_module("image_encoder"),
image_processor=self.get_module("image_processor"),
))
self.add_stage(stage_name="image_vae_encoding_stage",
stage=ImageVAEEncodingStage(
vae=self.get_module("vae"), ))
self.add_stage(stage_name="video_encoding_stage",
stage=EncodingStage(vae=self.get_module("vae"), ))


class PreprocessPipelineT2V(ComposedPipelineBase):
_required_config_modules = [
"text_encoder", "tokenizer", "text_encoder_2", "tokenizer_2", "vae"
]

def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
assert fastvideo_args.preprocess_config is not None
self.add_stage(stage_name="text_transform_stage",
stage=TextTransformStage(
cfg_uncondition_drop_rate=fastvideo_args.
preprocess_config.training_cfg_rate,
seed=fastvideo_args.preprocess_config.seed,
))
# Tokenizer kwargs for the Llama and CLIP encoders are now supplied by their
# TextEncoderConfig defaults (see fastvideo/configs/models/encoders), so no
# per-pipeline override is needed here.
text_encoders = [
self.get_module("text_encoder"),
self.get_module("text_encoder_2")
]
tokenizers = [
self.get_module("tokenizer"),
self.get_module("tokenizer_2")
]

self.add_stage(stage_name="prompt_encoding_stage",
stage=TextEncodingStage(
text_encoders=text_encoders,
tokenizers=tokenizers,
))
self.add_stage(
stage_name="video_transform_stage",
stage=VideoTransformStage(
train_fps=fastvideo_args.preprocess_config.train_fps,
num_frames=fastvideo_args.preprocess_config.num_frames,
max_height=fastvideo_args.preprocess_config.max_height,
max_width=fastvideo_args.preprocess_config.max_width,
do_temporal_sample=fastvideo_args.preprocess_config.
do_temporal_sample,
))
self.add_stage(stage_name="video_encoding_stage",
stage=EncodingStage(vae=self.get_module("vae"), ))


EntryClass = [PreprocessPipelineI2V, PreprocessPipelineT2V]
16 changes: 3 additions & 13 deletions fastvideo/workflow/preprocess/components.py
@@ -14,7 +14,6 @@
from datasets import Dataset, Video, load_dataset

from fastvideo.configs.configs import DatasetType, PreprocessConfig
from fastvideo.distributed.parallel_state import get_world_rank, get_world_size
Collaborator: There are some issues with your rebase. You should align your code with our implementation (accept both) instead of using only your code (accept current). Please revert this.

from fastvideo.logger import init_logger
from fastvideo.pipelines.pipeline_batch_info import PreprocessBatch

@@ -79,8 +78,8 @@ def __call__(self, batch: dict[str, Any]) -> bool:

def _validate_data_type(self, batch: dict[str, Any]) -> bool:
"""Validate basic validity of data items"""
return not (batch["caption"] is None or batch["caption"] == ""
or batch["fps"] is None or batch["fps"] <= 0
return not (batch["caption"] is None or batch["caption"] == "" or "fps"
not in batch or batch["fps"] is None or batch["fps"] <= 0
or batch["num_frames"] is None or batch["num_frames"] <= 0)
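The added "fps" not in batch guard means records missing the key are rejected instead of raising KeyError. A one-line sketch (the record is illustrative):

batch = {"caption": "a dog", "num_frames": 90}  # no "fps" key
# old check: batch["fps"] raises KeyError; new check: validator returns False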

def _validate_resolution(self, batch: dict[str, Any]) -> bool:
@@ -400,13 +399,9 @@ def _default_file_writer_fn(self, args_tuple: tuple) -> int:
return written_count


def build_dataset(preprocess_config: PreprocessConfig, split: str,
                  validator: Callable[[dict[str, Any]], bool]) -> Dataset:

Collaborator: revert this.
def build_dataset(preprocess_config: PreprocessConfig, split: str) -> Dataset:
if preprocess_config.dataset_type == DatasetType.HF:
dataset = load_dataset(preprocess_config.dataset_path, split=split)
dataset = dataset.filter(validator)
dataset = dataset.shard(num_shards=get_world_size(),
index=get_world_rank())
elif preprocess_config.dataset_type == DatasetType.MERGED:
metadata_json_path = os.path.join(preprocess_config.dataset_path,
"videos2caption.json")
Expand All @@ -420,11 +415,6 @@ def build_dataset(preprocess_config: PreprocessConfig, split: str,
dataset = dataset.rename_column("cap", "caption")
if "path" in column_names:
dataset = dataset.rename_column("path", "name")

dataset = dataset.filter(validator)
dataset = dataset.shard(num_shards=get_world_size(),
index=get_world_rank())

# add video column
def add_video_column(item: dict[str, Any]) -> dict[str, Any]:
item["video"] = os.path.join(video_folder, item["name"])
Expand Down
10 changes: 5 additions & 5 deletions fastvideo/workflow/preprocess/preprocess_workflow.py
@@ -44,9 +44,9 @@ def register_components(self) -> None:
self.add_component("raw_data_validator", raw_data_validator)

# training dataset
training_dataset = build_dataset(preprocess_config,
                                 split="train",
                                 validator=raw_data_validator)

Collaborator: revert this.
training_dataset = build_dataset(preprocess_config, split="train")
# set load_from_cache_file to False to check filter stats
training_dataset = training_dataset.filter(raw_data_validator)
# we do not use collate_fn here because we use iterable-style Dataset
# and want to keep the original type of the dataset
training_dataloader = DataLoader(
@@ -60,8 +60,8 @@ def register_components(self) -> None:
# try to load validation dataset if it exists
try:
validation_dataset = build_dataset(preprocess_config,
split="validation",
validator=raw_data_validator)
split="validation")
validation_dataset = validation_dataset.filter(raw_data_validator)
validation_dataloader = DataLoader(
validation_dataset,
batch_size=preprocess_config.preprocess_video_batch_size,
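Moving the filter out of build_dataset and into the workflow keeps dataset construction and validation separate. A toy sketch of the filter pattern (fields and values are illustrative):

from datasets import Dataset

ds = Dataset.from_list([
    {"caption": "a cat", "fps": 24, "num_frames": 120},
    {"caption": "", "fps": 24, "num_frames": 120},  # dropped: empty caption
])
valid = ds.filter(lambda b: bool(b["caption"]) and b["fps"] > 0
                  and b["num_frames"] > 0)
print(len(valid))  # 1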
7 changes: 5 additions & 2 deletions scripts/dataset_preparation/prepare_json_file.py
@@ -23,8 +23,10 @@ def get_video_info(video_path):
fps = info.get("video_fps", 0)
duration = num_frames / fps if fps > 0 else 0

# Extract name
_, _, videos_dir, video_name = str(video_path).split("/")
from pathlib import Path
video_path = Path(video_path)
videos_dir = video_path.parent.name
video_name = video_path.name
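The pathlib replacement is robust to absolute paths and arbitrary nesting depth, unlike the fixed four-way split it replaces. A quick sketch (the path is illustrative):

from pathlib import Path

p = Path("data/crush-smol/videos/clip_0001.mp4")
print(p.parent.name)  # videos
print(p.name)         # clip_0001.mp4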

return {
"path": str(video_name),
@@ -100,6 +102,7 @@ def prepare_dataset_json(folder_path,

# Save to JSON file
output_file = folder_path / output_name
print(folder_path,output_file,output_name)
Collaborator: remove this.

with open(output_file, 'w') as f:
json.dump(dataset_info, f, indent=2)
