diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1627645 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +__pycache__/ +weights/ +local_data/ +save_audio/ +.vscode/ \ No newline at end of file diff --git a/backend_runs/69a0754e42141599e5301cff/audio/avatar/1.pt b/backend_runs/69a0754e42141599e5301cff/audio/avatar/1.pt new file mode 100644 index 0000000..7ad54eb Binary files /dev/null and b/backend_runs/69a0754e42141599e5301cff/audio/avatar/1.pt differ diff --git a/backend_runs/69a0754e42141599e5301cff/audio/avatar/s1.wav b/backend_runs/69a0754e42141599e5301cff/audio/avatar/s1.wav new file mode 100644 index 0000000..e310693 Binary files /dev/null and b/backend_runs/69a0754e42141599e5301cff/audio/avatar/s1.wav differ diff --git a/backend_runs/69a0754e42141599e5301cff/f799126dcf204820adca9bfdf0a35b79.json b/backend_runs/69a0754e42141599e5301cff/f799126dcf204820adca9bfdf0a35b79.json new file mode 100644 index 0000000..52b6f6d --- /dev/null +++ b/backend_runs/69a0754e42141599e5301cff/f799126dcf204820adca9bfdf0a35b79.json @@ -0,0 +1 @@ +{"prompt": "A professional speaks confidently directly to the camera.", "cond_image": "/mnt/c/Users/anwan/OneDrive/Khan/maity/vidLink/output/videos/69a0754e42141599e5301cff/avatar.png", "tts_audio": {"text": "We are ready for liftoff... 
finally!", "human1_voice": "/mnt/c/Users/anwan/OneDrive/Khan/maity/vidLink/video_generators/multitalk/weights/Kokoro-82M/voices/af_heart.pt"}, "cond_audio": {}} \ No newline at end of file diff --git a/base_tts_template.json b/base_tts_template.json new file mode 100644 index 0000000..2a9ed1e --- /dev/null +++ b/base_tts_template.json @@ -0,0 +1,10 @@ +{ + "prompt": "A confident representative speaks directly to the camera.", + "cond_image": "Input_outputs/input_files/sales_executive/executive.png", + "tts_audio": { + "text": "example", + "human1_voice": "weights/Kokoro-82M/voices/af_heart.pt" + }, + "cond_audio": {} +} + diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..781e72e --- /dev/null +++ b/cli.py @@ -0,0 +1,417 @@ +""" +@file cli.py +@brief Backend-facing CLI wrapper for the MultiTalk generator. +""" + +import argparse +import json +import os +import shutil +import subprocess +import sys +import uuid +from typing import Any, Dict, Tuple +from urllib.parse import urlparse + +import requests + +import config + +# Characters per second for speech duration estimation (matches app/services/eta_service.py) +CHARS_PER_SECOND = 15.0 + + +def _estimate_speech_duration_seconds(speech_text: str) -> float | None: + """ + @brief Estimate video duration in seconds from speech text using character count. + @param speech_text The speech text to estimate duration for. + @return Estimated duration in seconds, or None if speech_text is empty/missing. + """ + if not speech_text or not isinstance(speech_text, str): + return None + + # Estimate duration: characters / characters_per_second + duration = len(speech_text) / CHARS_PER_SECOND + return duration + + +def _run_command_streaming(command: list[str], cwd: str) -> None: + """ + @brief Run a subprocess while streaming stdout/stderr to the current process. + @param command Command list to execute. + @param cwd Working directory for the subprocess. + @throws RuntimeError when the command exits non-zero. 
+ """ + + proc = subprocess.Popen( + command, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + ) + + tail: list[str] = [] + assert proc.stdout is not None + for line in proc.stdout: + sys.stdout.write(line) + sys.stdout.flush() + tail.append(line) + if len(tail) > 200: + tail.pop(0) + + rc = proc.wait() + if rc != 0: + tail_text = "".join(tail).strip() + raise RuntimeError( + f"multitalk generation failed with exit code {rc}. Last output:\n{tail_text}" + ) + + +def _resolve_path(base_dir: str, path_value: str) -> str: + """ + @brief Resolve a path relative to the multitalk repo if needed. + @param base_dir Base directory for relative resolution. + @param path_value Path to resolve. + @return Absolute path for the input value. + """ + + if os.path.isabs(path_value): + return path_value + return os.path.abspath(os.path.join(base_dir, path_value)) + + +def _ensure_kokoro_weights(repo_dir: str) -> None: + """ + @brief Ensure the Kokoro weights are reachable via `weights/Kokoro-82M` from repo_dir. + @details MultiTalk's `generate_multitalk.py` uses a relative repo_id (`weights/Kokoro-82M`), + so we provide a symlink to the absolute directory configured in `config.KOKORO_DIR`. + @param repo_dir MultiTalk repo directory. + """ + + kokoro_dir = getattr(config, "KOKORO_DIR", "") + if not kokoro_dir: + return + + weights_dir = os.path.join(repo_dir, "weights") + os.makedirs(weights_dir, exist_ok=True) + link_path = os.path.join(weights_dir, "Kokoro-82M") + + if os.path.exists(link_path): + return + + try: + os.symlink(kokoro_dir, link_path) + except Exception: + # If symlinks are not permitted, fall back to doing nothing; the generator will error clearly. + pass + + +def _select_avatar_assets(avatar_dir: str) -> Tuple[str, str]: + """ + @brief Select the avatar configuration JSON and image file from a directory. 
+ @details Finds the first .json file (assumed to be base.json or config) and + the first image file (.png, .jpg, .jpeg, or .webp) in the directory. + Files are selected in sorted order, so naming is predictable. + @param avatar_dir Directory containing avatar assets. + @return Tuple of (json_config_path, image_file_path). + @throws RuntimeError if required files are missing. + """ + + if not os.path.isdir(avatar_dir): + raise RuntimeError(f"Avatar directory not found: {avatar_dir}") + + entries = sorted(os.listdir(avatar_dir)) + json_path = "" + image_path = "" + for name in entries: + path = os.path.join(avatar_dir, name) + if os.path.isdir(path): + continue + lower = name.lower() + if lower.endswith(".json") and not json_path: + json_path = path + elif lower.endswith((".png", ".jpg", ".jpeg", ".webp")) and not image_path: + image_path = path + + if not json_path or not image_path: + raise RuntimeError( + f"Avatar directory must contain one JSON config file and one image file (.png/.jpg/.jpeg/.webp): {avatar_dir}" + ) + + return json_path, image_path + +def _resolve_voice_path(preferred_voice: str | None, base_dir: str) -> str: + """ + @brief Resolve a preferred voice identifier to an absolute voice file path. + @details Accepts a full relative path (e.g. ``weights/Kokoro-82M/voices/af_heart.pt``), + a filename (``af_heart.pt``), or just a voice name (``af_heart``). + @param preferred_voice Voice identifier supplied by the caller. May be *None*. + @param base_dir Base directory for relative path resolution (multitalk repo root). + @return Absolute path to the voice ``.pt`` file. + """ + + if not preferred_voice: + return _resolve_path(base_dir, config.TTS_VOICE) + + # Already looks like a path (contains separator or ends with .pt) + if "/" in preferred_voice or preferred_voice.endswith(".pt"): + return _resolve_path(base_dir, preferred_voice) + + # Bare voice name → resolve inside the Kokoro voices directory. 
+ voice_path = f"weights/Kokoro-82M/voices/{preferred_voice}.pt" + return _resolve_path(base_dir, voice_path) + + +def _build_input_from_template( + data: Dict[str, Any], base_dir: str, avatar_image_path: str +) -> Dict[str, Any]: + """ + @brief Build an input payload using the checked-in base TTS template. + @details Used when the caller supplies a ``avatar`` name (found in S3 bucket under ``avatars/``) + The template's ``cond_image``, ``tts_audio.text``, and ``tts_audio.human1_voice`` + are replaced with the caller-provided values. + @param data Raw job data dictionary. + @param base_dir Multitalk repo directory (for path resolution). + @param avatar_image_path Absolute path to the downloaded avatar image. + @return Payload dictionary ready for multitalk ``input_json``. + @throws RuntimeError when required fields are missing. + """ + + template_path = os.path.join(base_dir, "base_tts_template.json") + payload = _load_json(template_path) + + # Replace the avatar image with the downloaded one. + payload["cond_image"] = avatar_image_path + + if "cond_audio" not in payload: + payload["cond_audio"] = {} + + # Speech text is mandatory. + speech_text = data.get("speech_text") + if not speech_text: + raise RuntimeError("speech_text is required for multitalk TTS mode") + + # Resolve the voice file path. + preferred_voice = data.get("preferredVoice") + voice_path = _resolve_voice_path(preferred_voice, base_dir) + + payload["tts_audio"] = { + "text": speech_text, + "human1_voice": voice_path, + } + + return payload + + +def _build_input_payload( + data: Dict[str, Any], base_dir: str, +) -> Dict[str, Any]: + """ + @brief Build the input payload expected by the multitalk generator. 
+ @details Transforms avatar config JSON into the multitalk generator format: + - Extracts prompt and voice from avatar JSON + - Resolves image path to absolute path + - Converts voice name to path: weights/Kokoro-82M/voices/{voice}.pt + - Combines with speech_text from job data into tts_audio + @param data Raw job data from the backend JSON file. + @param base_dir Base directory for resolving paths (repo directory). + @return Payload dictionary ready for multitalk input_json. + @throws RuntimeError when required fields are missing. + """ + + # Extract required fields from avatar config + prompt = data.get("video_prompt") # Fallback to default prompt if not specified in avatar config + if not prompt: + raise RuntimeError(f"Job data must contain 'video_prompt' field: {data}") + + voice = data.get("kokoro_voice") + if not voice: + raise RuntimeError(f"Job data must contain 'kokoro_voice' field: {data}") + + # Get speech text from job data + speech_text = data.get("speech_text") + if not speech_text: + raise RuntimeError(f"Job data must contain 'speech_text' field: {data}") + + avatar_path = data.get("avatar_path") + if not avatar_path: + raise RuntimeError(f"Job data must contain 'avatar_path' field: {data}") + + # Build the payload in the expected format + payload = { + "prompt": prompt, + "cond_image": _resolve_path(base_dir, avatar_path), + "tts_audio": { + "text": speech_text, + "human1_voice": _resolve_path(base_dir, f"weights/Kokoro-82M/voices/{voice}.pt"), + }, + "cond_audio": {}, + } + + return payload + + +def _load_json(path: str) -> Dict[str, Any]: + """ + @brief Load JSON content from disk. + @param path File path to read. + @return Parsed JSON dictionary. + @throws RuntimeError if JSON cannot be read. 
+ """ + + try: + with open(path, "r", encoding="utf-8") as handle: + return json.load(handle) + except Exception as exc: + raise RuntimeError(f"Failed to read JSON from {path}: {exc}") from exc + + +def _write_json(path: str, payload: Dict[str, Any]) -> None: + """ + @brief Write JSON content to disk. + @param path File path to write. + @param payload JSON-serializable dictionary. + """ + + with open(path, "w", encoding="utf-8") as handle: + json.dump(payload, handle) + + +def main() -> None: + """ + @brief CLI entrypoint for backend-triggered multitalk generation. + @throws RuntimeError on invalid input or generation failure. + """ + + parser = argparse.ArgumentParser( + description="Backend wrapper for MultiTalk generation", + ) + parser.add_argument("--job-id", required=True) + parser.add_argument("--output", required=True) + parser.add_argument("--data", required=True) + parser.add_argument("--work-dir", default=None) + args = parser.parse_args() + + repo_dir = os.path.dirname(os.path.abspath(__file__)) + + work_dir = args.work_dir + if work_dir is None: + print('WARNING: no --work-dir specified, using default "backend_runs/{job_id}"') + work_dir = os.path.join(repo_dir, "backend_runs", args.job_id) + os.makedirs(work_dir, exist_ok=True) + + input_json_path = os.path.join(work_dir, f"cond.json") + audio_save_dir = os.path.join(work_dir, "audio") + + data = _load_json(args.data) + payload = _build_input_payload( + data=data, + base_dir=repo_dir, + ) + _write_json(input_json_path, payload) + + ckpt_dir = config.CKPT_DIR + wav2vec_dir = config.WAV2VEC_DIR + if not ckpt_dir or not wav2vec_dir: + raise RuntimeError("Missing CKPT_DIR or WAV2VEC_DIR in config.py") + + ckpt_dir = _resolve_path(repo_dir, ckpt_dir) + wav2vec_dir = _resolve_path(repo_dir, wav2vec_dir) + + _ensure_kokoro_weights(repo_dir) + + # Calculate frame_num and max_frames_num based on video duration + FPS = 25.0 # Video frames per second + speech_text = data.get("speech_text", "") + 
video_duration_seconds = ( + _estimate_speech_duration_seconds(speech_text) if speech_text else None + ) + + # Choose the largest safe frame_num for this text, with safety margin. + # Keep it in [33, 81] and enforce frame_num = 4n+1. + MIN_FRAMES = 33 # 4*8 + 1 + MAX_FRAMES = 81 # 4*20 + 1 + + if video_duration_seconds is not None: + # Estimated frames for the (TTS) audio, with a small safety margin + frames_estimated = int(video_duration_seconds * FPS) + frames_target = int(frames_estimated * 0.9) # 10% safety margin + + # Clamp into [MIN_FRAMES, MAX_FRAMES] + frames_target = max(MIN_FRAMES, min(MAX_FRAMES, frames_target)) + + # frame_num must be 4n+1 -> round DOWN to nearest 4n+1 <= frames_target + remainder = (frames_target - 1) % 4 + frame_num = frames_target - remainder + if frame_num < MIN_FRAMES: + frame_num = MIN_FRAMES + else: + # Unknown duration: fall back to the most conservative (max) clip length + frame_num = MAX_FRAMES + + # # Calculate max_frames_num for longer videos + mode = "streaming" + # if mode == "clip": + # max_frames_num = frame_num + # else: + # # Streaming mode: calculate frames needed based on video duration + # if video_duration_seconds is not None: + # # Calculate frames needed: duration * fps + # frames_needed = int(video_duration_seconds * FPS) + # # Ensure it's at least frame_num + # max_frames_num = max(frame_num, frames_needed) + # # Round up to next 4n+1 if needed (to match frame_num pattern) + # remainder = (max_frames_num - 1) % 4 + # if remainder != 0: + # max_frames_num = max_frames_num + (4 - remainder) + # else: + # max_frames_num = 1000 # default for streaming when duration unknown + max_frames_num = 2000 + + command = [ + sys.executable, + os.path.join(repo_dir, "generate_multitalk.py"), + "--ckpt_dir", + ckpt_dir, + "--wav2vec_dir", + wav2vec_dir, + "--input_json", + input_json_path, + "--sample_steps", + str(data.get("sample_steps", getattr(config, "SAMPLE_STEPS", 40))), + "--mode", + mode, + 
"--num_persistent_param_in_dit", + str(data.get("num_persistent_param_in_dit", 0)), + "--audio_mode", + "tts", + "--audio_save_dir", + audio_save_dir, + "--save_file", + os.path.splitext(args.output)[0], + "--frame_num", + str(frame_num), + ] + + # Add max_frames_num argument for streaming mode + if mode == "streaming": + command.extend(["--max_frames_num", str(max_frames_num)]) + + if data.get("use_teacache", True): + command.append("--use_teacache") + + try: + _run_command_streaming(command, cwd=repo_dir) + finally: + try: + shutil.rmtree(work_dir) + # print(f"Skipping cleanup of work_dir: {work_dir}\n command ran \n {command}") + except Exception: + pass + + +if __name__ == "__main__": + main() + diff --git a/config.py b/config.py new file mode 100644 index 0000000..494dd6a --- /dev/null +++ b/config.py @@ -0,0 +1,10 @@ +AVATAR_DIR = '/home/web/partha/vidLink/video_generators/multitalk/Input_outputs/input_files/sales_executive' +CKPT_DIR = './weights/Wan2.1-I2V-14B-480P' +WAV2VEC_DIR = './weights/chinese-wav2vec2-base' +KOKORO_DIR = './weights/Kokoro-82M' +TTS_VOICE = './weights/Kokoro-82M/voices/af_heart.pt' +SAMPLE_STEPS = 30 + +# Time required to generate 1 second of video (in seconds). 
+# For multitalk: 5 minutes per video second = 300 seconds +TIME_PER_VIDEO_SECOND_SECONDS = 300.0 diff --git a/custom.sh b/custom.sh new file mode 100644 index 0000000..944c0d4 --- /dev/null +++ b/custom.sh @@ -0,0 +1,12 @@ + +export WAN_DISABLE_FLASH_ATTN="1" + +python generate_multitalk.py \ + --ckpt_dir weights/Wan2.1-I2V-14B-480P \ + --wav2vec_dir 'weights/chinese-wav2vec2-base' \ + --input_json /mnt/c/Users/anwan/OneDrive/Khan/maity/vidLink/local_data/avatars/outreach/base.json \ + --sample_steps 20 \ + --mode streaming \ + --num_persistent_param_in_dit 0 \ + --audio_mode tts \ + --audio_save_dir local_data/outreach_mighty diff --git a/generate_multitalk.py b/generate_multitalk.py index 9a3b525..7516d0a 100644 --- a/generate_multitalk.py +++ b/generate_multitalk.py @@ -80,6 +80,12 @@ def _parse_args(): default=81, help="How many frames to be generated in one clip. The number should be 4n+1" ) + parser.add_argument( + "--max_frames_num", + type=int, + default=None, + help="Maximum number of frames to generate (for streaming mode). If not provided, defaults to frame_num for clip mode or 1000 for streaming mode."
+ ) parser.add_argument( "--ckpt_dir", type=str, @@ -252,9 +258,9 @@ def _parse_args(): return args def custom_init(device, wav2vec): - audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device) + audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True, attn_implementation="eager").to(device) audio_encoder.feature_extractor._freeze_parameters() - wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True) + wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True, attn_implementation="eager") return wav2vec_feature_extractor, audio_encoder def loudness_norm(audio_array, sr=16000, lufs=-23): @@ -613,7 +619,7 @@ def generate(args): audio_guide_scale=args.sample_audio_guide_scale, seed=args.base_seed, offload_model=args.offload_model, - max_frames_num=args.frame_num if args.mode == 'clip' else 1000, + max_frames_num=args.max_frames_num if args.max_frames_num is not None else (args.frame_num if args.mode == 'clip' else 1000), color_correction_strength = args.color_correction_strength, extra_args=args, ) diff --git a/mini.sh b/mini.sh new file mode 100644 index 0000000..25d3aa6 --- /dev/null +++ b/mini.sh @@ -0,0 +1,12 @@ + +export WAN_DISABLE_FLASH_ATTN="1" + +python generate_multitalk.py \ + --ckpt_dir weights/Wan2.1-I2V-14B-480P \ + --wav2vec_dir 'weights/chinese-wav2vec2-base' \ + --input_json /mnt/c/Users/anwan/OneDrive/Khan/maity/vidLink/local_data/avatars/sales_executive/old-base.json \ + --sample_steps 8 \ + --mode streaming \ + --num_persistent_param_in_dit 30 \ + --audio_mode tts \ + --audio_save_dir local_data/sales_test diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 9f11ef7..d060eff 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -1,4 +1,5 @@ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import os import torch import torch.nn as nn from einops import rearrange, repeat @@ -10,17 +11,29 @@ ) import xformers.ops -try: - import flash_attn_interface - FLASH_ATTN_3_AVAILABLE = True -except ModuleNotFoundError: - FLASH_ATTN_3_AVAILABLE = False +_DISABLE_FLASH_ATTN = os.getenv("WAN_DISABLE_FLASH_ATTN", "").strip().lower() in { + "1", + "true", + "yes", + "on", +} -try: - import flash_attn - FLASH_ATTN_2_AVAILABLE = True -except ModuleNotFoundError: +if _DISABLE_FLASH_ATTN: + FLASH_ATTN_3_AVAILABLE = False FLASH_ATTN_2_AVAILABLE = False + print("Flash attention is disabled by WAN_DISABLE_FLASH_ATTN environment variable.") +else: + try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True + except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + + try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True + except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False import warnings @@ -62,6 +75,22 @@ def flash_attention( assert dtype in half_dtypes assert q.device.type == 'cuda' and q.size(-1) <= 256 + if not FLASH_ATTN_3_AVAILABLE and not FLASH_ATTN_2_AVAILABLE: + if q_lens is not None or k_lens is not None: + warnings.warn( + 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.' + ) + attn_mask = None + + q_sdpa = q.transpose(1, 2).to(dtype) + k_sdpa = k.transpose(1, 2).to(dtype) + v_sdpa = v.transpose(1, 2).to(dtype) + + out = torch.nn.functional.scaled_dot_product_attention( + q_sdpa, k_sdpa, v_sdpa, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p) + + return out.transpose(1, 2).contiguous().type(q.dtype) + # params b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype diff --git a/wan/modules/attention_blackwell.py b/wan/modules/attention_blackwell.py new file mode 100644 index 0000000..f0b82dd --- /dev/null +++ b/wan/modules/attention_blackwell.py @@ -0,0 +1,568 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +import torch.nn as nn +from einops import rearrange, repeat +from ..utils.multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids +from xfuser.core.distributed import ( + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + get_sp_group, +) +import xformers.ops + +try: + import flash_attn_interface + FLASH_ATTN_3_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_3_AVAILABLE = False + +try: + import flash_attn + FLASH_ATTN_2_AVAILABLE = True +except ModuleNotFoundError: + FLASH_ATTN_2_AVAILABLE = False + +import warnings + +__all__ = [ + 'flash_attention', + 'attention', +] + +def flash_attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + version=None, +): + """ + Modified flash_attention with intelligent fallback for older/alternative GPUs. + Prioritizes: Flash Attention 3/2 > xformers > SDPA + """ + # --- FALLBACK: If Flash Attention is missing, use xformers or SDPA --- + if not (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE): + # Try xformers fallback first (better than SDPA) + try: + q_in = q.transpose(1, 2).to(dtype) + k_in = k.transpose(1, 2).to(dtype) + v_in = v.transpose(1, 2).to(dtype) + + if q_scale is not None: + q_in = q_in * q_scale + + # Apply softmax scale to query if provided + if softmax_scale is not None: + # SDPA uses 1/sqrt(dim) scaling by default + default_scale = q_in.shape[-1] ** -0.5 + q_in = q_in * (softmax_scale / default_scale) + + # Use xformers memory efficient attention if available + out = xformers.ops.memory_efficient_attention( + q_in, k_in, v_in, + attn_bias=None, + op=None, + ) + return out.contiguous().type(q.dtype) + except Exception as xf_err: + # Fall back to SDPA if xformers fails + warnings.warn( + f'xformers attention failed ({xf_err}), falling back to SDPA. 
' + 'For better performance on RTX 4090/3090, install flash-attn: ' + 'pip install flash-attn' + ) + q_in = q.transpose(1, 2).to(dtype) + k_in = k.transpose(1, 2).to(dtype) + v_in = v.transpose(1, 2).to(dtype) + + if q_scale is not None: + q_in = q_in * q_scale + + if softmax_scale is not None: + default_scale = q_in.shape[-1] ** -0.5 + q_in = q_in * (softmax_scale / default_scale) + + # SDPA has limited support for window_size, so log a warning + if window_size != (-1, -1): + warnings.warn('SDPA fallback does not support sliding window attention') + + out = torch.nn.functional.scaled_dot_product_attention( + q_in, k_in, v_in, + dropout_p=dropout_p, + is_causal=causal + ) + return out.transpose(1, 2).contiguous().type(q.dtype) + # ---------------------------------------------------------------- + + half_dtypes = (torch.float16, torch.bfloat16) + assert dtype in half_dtypes + assert q.device.type == 'cuda' and q.size(-1) <= 256 + + # params + b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + + def half(x): + return x if x.dtype in half_dtypes else x.to(dtype) + + # preprocess query + if q_lens is None: + q = half(q.flatten(0, 1)) + q_lens = torch.tensor( + [lq] * b, dtype=torch.int32).to( + device=q.device, non_blocking=True) + else: + q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + + # preprocess key, value + if k_lens is None: + k = half(k.flatten(0, 1)) + v = half(v.flatten(0, 1)) + k_lens = torch.tensor( + [lk] * b, dtype=torch.int32).to( + device=k.device, non_blocking=True) + else: + k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) + v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + + q = q.to(v.dtype) + k = k.to(v.dtype) + + if q_scale is not None: + q = q * q_scale + + if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: + warnings.warn( + 'Flash attention 3 is not available, use flash attention 2 instead.' 
+ ) + + # apply attention + if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: + # Note: dropout_p, window_size are not supported in FA3 now. + x = flash_attn_interface.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + seqused_q=None, + seqused_k=None, + max_seqlen_q=lq, + max_seqlen_k=lk, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic)[0].unflatten(0, (b, lq)) + else: + assert FLASH_ATTN_2_AVAILABLE + x = flash_attn.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( + 0, dtype=torch.int32).to(q.device, non_blocking=True), + max_seqlen_q=lq, + max_seqlen_k=lk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic).unflatten(0, (b, lq)) + + # output + return x.type(out_dtype) + +# def flash_attention( +# q, +# k, +# v, +# q_lens=None, +# k_lens=None, +# dropout_p=0., +# softmax_scale=None, +# q_scale=None, +# causal=False, +# window_size=(-1, -1), +# deterministic=False, +# dtype=torch.bfloat16, +# version=None, +# ): +# """ +# q: [B, Lq, Nq, C1]. +# k: [B, Lk, Nk, C1]. +# v: [B, Lk, Nk, C2]. Nq must be divisible by Nk. +# q_lens: [B]. +# k_lens: [B]. +# dropout_p: float. Dropout probability. +# softmax_scale: float. The scaling of QK^T before applying softmax. +# causal: bool. Whether to apply causal attention mask. +# window_size: (left right). If not (-1, -1), apply sliding window local attention. +# deterministic: bool. If True, slightly slower and uses more memory. +# dtype: torch.dtype. 
Apply when dtype of q/k/v is not float16/bfloat16. +# """ +# half_dtypes = (torch.float16, torch.bfloat16) +# assert dtype in half_dtypes +# assert q.device.type == 'cuda' and q.size(-1) <= 256 + +# # params +# b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype + +# def half(x): +# return x if x.dtype in half_dtypes else x.to(dtype) + +# # preprocess query +# if q_lens is None: +# q = half(q.flatten(0, 1)) +# q_lens = torch.tensor( +# [lq] * b, dtype=torch.int32).to( +# device=q.device, non_blocking=True) +# else: +# q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)])) + +# # preprocess key, value +# if k_lens is None: +# k = half(k.flatten(0, 1)) +# v = half(v.flatten(0, 1)) +# k_lens = torch.tensor( +# [lk] * b, dtype=torch.int32).to( +# device=k.device, non_blocking=True) +# else: +# k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)])) +# v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)])) + +# q = q.to(v.dtype) +# k = k.to(v.dtype) + +# if q_scale is not None: +# q = q * q_scale + +# if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE: +# warnings.warn( +# 'Flash attention 3 is not available, use flash attention 2 instead.' +# ) + +# # apply attention +# if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE: +# # Note: dropout_p, window_size are not supported in FA3 now. 
+# x = flash_attn_interface.flash_attn_varlen_func( +# q=q, +# k=k, +# v=v, +# cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( +# 0, dtype=torch.int32).to(q.device, non_blocking=True), +# cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( +# 0, dtype=torch.int32).to(q.device, non_blocking=True), +# seqused_q=None, +# seqused_k=None, +# max_seqlen_q=lq, +# max_seqlen_k=lk, +# softmax_scale=softmax_scale, +# causal=causal, +# deterministic=deterministic)[0].unflatten(0, (b, lq)) +# else: +# assert FLASH_ATTN_2_AVAILABLE +# x = flash_attn.flash_attn_varlen_func( +# q=q, +# k=k, +# v=v, +# cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum( +# 0, dtype=torch.int32).to(q.device, non_blocking=True), +# cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum( +# 0, dtype=torch.int32).to(q.device, non_blocking=True), +# max_seqlen_q=lq, +# max_seqlen_k=lk, +# dropout_p=dropout_p, +# softmax_scale=softmax_scale, +# causal=causal, +# window_size=window_size, +# deterministic=deterministic).unflatten(0, (b, lq)) + +# # output +# return x.type(out_dtype) + + +def attention( + q, + k, + v, + q_lens=None, + k_lens=None, + dropout_p=0., + softmax_scale=None, + q_scale=None, + causal=False, + window_size=(-1, -1), + deterministic=False, + dtype=torch.bfloat16, + fa_version=None, +): + if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: + return flash_attention( + q=q, + k=k, + v=v, + q_lens=q_lens, + k_lens=k_lens, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + q_scale=q_scale, + causal=causal, + window_size=window_size, + deterministic=deterministic, + dtype=dtype, + version=fa_version, + ) + else: + # Fallback: use xformers first, then SDPA + try: + q_t = q.transpose(1, 2).to(dtype) + k_t = k.transpose(1, 2).to(dtype) + v_t = v.transpose(1, 2).to(dtype) + + if q_scale is not None: + q_t = q_t * q_scale + + out = xformers.ops.memory_efficient_attention( + q_t, k_t, v_t, + attn_bias=None, + op=None, + ) + return 
# NOTE(review): this chunk begins mid-function. The fragment below is the tail
# of an xformers-backed attention helper: the xformers fast path returns just
# above this view, and the `except` implements a pure-PyTorch SDPA fallback.
out.contiguous().type(q.dtype)
    except Exception as xf_err:
        # Final fallback to SDPA
        warnings.warn(
            f'xformers attention failed ({xf_err}), falling back to SDPA. '
            'For better performance on RTX 4090/3090, install flash-attn: '
            'pip install flash-attn'
        )
        if q_lens is not None or k_lens is not None:
            # The SDPA call below passes attn_mask=None, so per-sequence
            # padding masks are silently dropped on this path — warn loudly.
            warnings.warn(
                'Padding mask is disabled with SDPA fallback. '
                'This can significantly impact performance.'
            )

        attn_mask = None
        # SDPA expects (B, H, L, K); inputs arrive as (B, L, H, K).
        q_t = q.transpose(1, 2).to(dtype)
        k_t = k.transpose(1, 2).to(dtype)
        v_t = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q_t, k_t, v_t, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        # Transpose back to (B, L, H, K) to match the xformers fast-path layout.
        out = out.transpose(1, 2).contiguous()
        return out


class SingleStreamAttention(nn.Module):
    """Cross-attention block: visual hidden states attend to encoder states.

    Queries are projected from `x`; keys and values come from
    `encoder_hidden_states` via a single fused KV projection. Attention is
    computed with xformers' memory-efficient kernel.
    """

    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ) -> None:
        """
        @param dim Hidden size of the visual stream; must divide num_heads.
        @param encoder_hidden_states_dim Hidden size of the encoder (KV) stream.
        @param num_heads Number of attention heads.
        @param qkv_bias Whether the Q and KV linear layers carry a bias.
        @param qk_norm Apply per-head normalization to Q and encoder K.
        @param norm_layer Norm class used for the per-head Q/K norms.
        @param attn_drop Dropout probability on attention (module is created
               but note the xformers call in forward() does not use it).
        @param proj_drop Dropout probability after the output projection.
        @param eps Epsilon passed to the Q/K norm layers.
        """
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # NOTE(review): scale is never referenced in forward(); xformers applies
        # its own 1/sqrt(head_dim) scaling internally.
        self.scale = self.head_dim**-0.5
        self.qk_norm = qk_norm

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)

        # Per-head norms; identity when qk_norm is disabled.
        self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim,eps=eps) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Fused projection producing both K and V from the encoder stream.
        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

        # NOTE(review): add_q_norm is constructed but never applied in
        # forward(); only add_k_norm is used (on the encoder keys). Confirm
        # against upstream whether this is intentional.
        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None,
                enable_sp=False, kv_seq=None) -> torch.Tensor:
        """Cross-attend `x` to `encoder_hidden_states`.

        @param x Visual tokens, (B, N_t*S, C) when enable_sp is False.
        @param encoder_hidden_states Encoder tokens, (B*N_t, N_a, C_enc)
               after the per-frame fold — presumably audio features; confirm
               against caller.
        @param shape Tuple (N_t, N_h, N_w): latent frame count and spatial dims.
        @param enable_sp Sequence-parallel mode; tokens are pre-split by rank
               and a block-diagonal mask confines each frame to its KV span.
        @param kv_seq Per-frame KV lengths for the block-diagonal mask;
               required when enable_sp is True.
        @return Tensor with the same layout as the input `x`.
        """
        N_t, N_h, N_w = shape
        if not enable_sp:
            # Fold frames into the batch axis so attention runs per-frame.
            x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))  # (B, H, N, head_dim)

        if self.qk_norm:
            q = self.q_norm(q)

        # get kv from encoder_hidden_states
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)  # each (B, H, N_a, head_dim)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        # xformers expects (B, M, H, K) layout.
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")

        if enable_sp:
            # context parallel
            sp_size = get_sequence_parallel_world_size()
            sp_rank = get_sequence_parallel_rank()
            visual_seqlen, _ = split_token_counts_and_frame_ids(N_t, N_h * N_w, sp_size, sp_rank)
            assert kv_seq is not None, f"kv_seq should not be None."
            # Block-diagonal mask: each frame's visual tokens attend only to
            # that frame's KV span.
            attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(visual_seqlen, kv_seq)
        else:
            attn_bias = None
        x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)  # merge heads back to channel dim
        x = self.proj(x)
        x = self.proj_drop(x)

        if not enable_sp:
            # reshape x to origin shape
            x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x


class SingleStreamMutiAttention(SingleStreamAttention):
    """Multi-speaker (two-person) variant of SingleStreamAttention.

    Routes each visual token toward one speaker's audio stream by assigning
    1-D RoPE positions: tokens dominated by speaker 1 (per `x_ref_attn_map`)
    get positions in one band, speaker-2 tokens in another band, and
    background tokens sit in the middle. Encoder keys receive fixed
    band-center positions so queries align with the matching speaker's audio.
    """

    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        class_range: int = 24,
        class_interval: int = 4,
    ) -> None:
        """
        @param class_range Total span of the 1-D RoPE position space.
        @param class_interval Width of each speaker's position band.
        Other parameters are forwarded to SingleStreamAttention unchanged.
        """
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=norm_layer,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.class_interval = class_interval
        self.class_range = class_range
        # Speaker 1 band [0, interval); speaker 2 band [range-interval, range);
        # background tokens are pinned to the midpoint of the range.
        self.rope_h1 = (0, self.class_interval)
        self.rope_h2 = (self.class_range - self.class_interval, self.class_range)
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
                x: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
                human_num=None) -> torch.Tensor:
        """Cross-attend with per-speaker RoPE routing.

        @param x Visual tokens, (B, N_t*S, C).
        @param encoder_hidden_states Audio tokens with a leading singleton dim
               (squeezed below); first half of the audio axis is presumed to
               belong to speaker 1, second half to speaker 2 — confirm upstream.
        @param shape (N_t, N_h, N_w) latent dims.
        @param x_ref_attn_map Reference attention map, indexed [speaker, token];
               assumed shape (2, N_t*S) — TODO confirm.
        @param human_num Number of speakers; 1 falls back to the base class.
        @return Tensor shaped like the input `x`.
        """
        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if human_num == 1:
            # Single speaker: no routing needed, use plain cross-attention.
            return super().forward(x, encoder_hidden_states, shape)

        N_t, _, _ = shape
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        # Per-speaker min/max over the reference attention map, used to
        # normalize each speaker's scores into that speaker's RoPE band.
        max_values = x_ref_attn_map.max(1).values[:, None, None]
        min_values = x_ref_attn_map.min(1).values[:, None, None]
        max_min_values = torch.cat([max_values, min_values], dim=2)

        human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
        human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

        # Scale speaker scores into their bands; background gets the midpoint.
        human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
        human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
        back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
        # Each token takes the position from whichever speaker dominates it.
        max_indices = x_ref_attn_map.argmax(dim=0)
        normalized_map = torch.stack([human1, human2, back], dim=1)
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N

        # Apply 1-D RoPE over the full (N_t*S) token sequence, then refold.
        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        # Encoder keys get fixed band-center positions: first half of the
        # audio tokens at speaker 1's band center, second half at speaker 2's,
        # repeated for every latent frame.
        per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
        per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame]*N_t, dim=0)
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # xformers expects (B, M, H, K) layout.
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x
# NOTE(review): the remainder of this chunk is a git-diff artifact (a hunk for
# wan/multitalk.py), not part of this module; left byte-identical below.
\ No newline at end of file diff --git a/wan/multitalk.py index 2685c42..6130e3c 100644 --- a/wan/multitalk.py +++ b/wan/multitalk.py @@ -440,7 +440,7 @@ def generate(self, continue full_audio_embs.append(full_audio_emb) - assert len(full_audio_embs) == HUMAN_NUMBER, f"Aduio file not exists or length not satisfies frame nums." + assert len(full_audio_embs) == HUMAN_NUMBER, f"Aduio file {audio_embedding_paths} not exists or length not satisfies frame nums." # preprocess text embedding if n_prompt == "":