diff --git a/app.py b/app.py index 8bd958f..27c45b6 100644 --- a/app.py +++ b/app.py @@ -42,7 +42,6 @@ for video_name in sorted(video_label_dict): video_stem = Path(video_label_dict[video_name]) servicer.update_video(video_stem, video_stem.with_suffix('.json'), - video_path=video_stem.with_suffix('.mp4'), name=video_name) for audio_name in sorted(audio_label_dict): diff --git a/config/nota_wav2lip.yaml b/config/nota_wav2lip.yaml index 58eda29..00ef05f 100644 --- a/config/nota_wav2lip.yaml +++ b/config/nota_wav2lip.yaml @@ -6,10 +6,8 @@ inference: w: 224 model: wav2lip: - cls: nota_wav2lip.models.Wav2Lip checkpoint: "checkpoints/lrs3-wav2lip.pth" nota_wav2lip: - cls: nota_wav2lip.models.NotaWav2Lip checkpoint: "checkpoints/lrs3-nota-wav2lip.pth" audio: diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..3906d2a --- /dev/null +++ b/inference.py @@ -0,0 +1,83 @@ +import os +import subprocess +from pathlib import Path +import argparse + +from config import hparams as hp +from nota_wav2lip import Wav2LipModelComparisonDemo + + +LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None) +LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None) + +if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None: + subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True) +if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None: + subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True) + +def parse_args(): + + parser = argparse.ArgumentParser(description="NotaWav2Lip: Inference snippet for your own video and audio pair") + + parser.add_argument( + '-a', + '--audio-input', + type=str, + required=True, + help="Path of the audio file" + ) + + parser.add_argument( + '-v', + '--video-frame-input', + type=str, + required=True, + help="Input directory with face image sequence. We recommend to extract the face image sequence with `preprocess.py`." + ) + + parser.add_argument( + '-b', + '--bbox-input', + type=str, + help="Path of the file with bbox coordinates. We recommend to extract the json file with `preprocess.py`." + "If None, it pretends that the json file is located at the same directory with face images: {VIDEO_FRAME_INPUT}.with_suffix('.json')." + ) + + parser.add_argument( + '-m', + '--model', + choices=['wav2lip', 'nota_wav2lip'], + default='nota_wav2ilp', + help="Model for generating talking video. Defaults: wav2lip" + ) + + parser.add_argument( + '-o', + '--output-dir', + type=str, + default="result", + help="Output directory to save the result. Defaults: result" + ) + + parser.add_argument( + '-d', + '--device', + choices=['cpu', 'cuda'], + default='cpu', + help="Device setting for model inference. Defaults: cpu" + ) + + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + bbox_input = args.bbox_input if args.bbox_input is not None \ + else Path(args.video_frame_input).with_suffix('.json') + + servicer = Wav2LipModelComparisonDemo(device=args.device, result_dir=args.output_dir, model_list=args.model) + servicer.update_audio(args.audio_input, name='a0') + servicer.update_video(args.video_frame_input, bbox_input, name='v0') + + servicer.save_as_video('a0', 'v0', args.model) \ No newline at end of file diff --git a/inference.sh b/inference.sh new file mode 100644 index 0000000..d7c9ff0 --- /dev/null +++ b/inference.sh @@ -0,0 +1,6 @@ +python inference.py\ + -a "sample/1673_orig.wav"\ + -v "sample_video_lrs3/EV3OmxrowWE-00003"\ + -m "nota_wav2lip"\ + -o "result"\ + --device cpu \ No newline at end of file diff --git a/nota_wav2lip/demo.py b/nota_wav2lip/demo.py index 76de8e7..ce26976 100644 --- a/nota_wav2lip/demo.py +++ b/nota_wav2lip/demo.py @@ -3,7 +3,7 @@ import subprocess import time from pathlib import Path -from typing import Dict, Iterator, List, Literal +from typing import Dict, Iterator, List, Literal, Optional, Union import cv2 import numpy as np @@ -11,12 +11,15 @@ from config import hparams as hp from nota_wav2lip.inference import Wav2LipInferenceImpl from nota_wav2lip.video import AudioSlicer, VideoSlicer +from nota_wav2lip.util import FFMPEG_LOGGING_MODE class Wav2LipModelComparisonDemo: - def __init__(self, device='cpu', result_dir='./temp', model_list: List[str]=None): + def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]]=None): if model_list is None: - model_list = ['wav2lip', 'nota_wav2lip'] + model_list: List[str] = ['wav2lip', 'nota_wav2lip'] + if isinstance(model_list, str) and len(model_list) != 0: + model_list: List[str] = [model_list] super().__init__() self.video_dict: Dict[str, VideoSlicer] = {} self.audio_dict: Dict[str, AudioSlicer] = {} @@ -25,7 +28,7 @@ def __init__(self, device='cpu', result_dir='./temp', model_list: List[str]=None for model_name in model_list: assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}" self.model_zoo[model_name] = Wav2LipInferenceImpl( - hp_inference_model=hp.inference.model[model_name], device=device + model_name, hp_inference_model=hp.inference.model[model_name], device=device ) self._params_zoo: Dict[str, str] = { @@ -56,16 +59,16 @@ def update_audio(self, audio_path, name=None): {_name: AudioSlicer(audio_path)} ) - def update_video(self, frame_dir_path, bbox_path, video_path=None, name=None): + def update_video(self, frame_dir_path, bbox_path, name=None): _name = name if name is not None else Path(frame_dir_path).stem self.video_dict.update( - {_name: VideoSlicer(frame_dir_path, bbox_path, video_path=video_path)} + {_name: VideoSlicer(frame_dir_path, bbox_path)} ) def save_as_video(self, audio_name, video_name, model_type): - output_video_path = self.result_dir / 'original_voice.mp4' - frame_only_video_path = self.result_dir / 'original.mp4' + output_video_path = self.result_dir / 'generated_with_audio.mp4' + frame_only_video_path = self.result_dir / 'generated.mp4' audio_path = self.audio_dict[audio_name].audio_path out = cv2.VideoWriter(str(frame_only_video_path), @@ -78,77 +81,11 @@ def save_as_video(self, audio_name, video_name, model_type): inference_time = time.time() - start out.release() - command = f"ffmpeg -hide_banner -loglevel error -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}" + command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}" subprocess.call(command, shell=platform.system() != 'Windows') # The number of frames of generated video video_frames_num = len(self.audio_dict[audio_name]) inference_fps = video_frames_num / inference_time - return output_video_path, inference_time, inference_fps - - -def get_parsed_args(): - - parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models') - - parser.add_argument('--resize_factor', default=1, type=int, - help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p') - - parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1], - help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. ' - 'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width') - - parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1], - help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected.' - 'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).') - - args = parser.parse_args() - - return args - - -def main(): - get_parsed_args() - - demo_generator = Wav2LipModelComparisonDemo() - demo_generator.update_audio("sample/1673_orig.wav", name="1673") - demo_generator.update_audio("sample/4598_orig.wav", name="4598") - demo_generator.update_video("sample/2145_orig", "sample/2145_orig.json", name="2145") - demo_generator.update_video("sample/2942_orig", "sample/2942_orig.json", name="2942") - - processed_time = [] - for _i in range(5): - start = time.time() - out = cv2.VideoWriter('temp/original.mp4', - cv2.VideoWriter_fourcc(*'mp4v'), - hp.face.video_fps, - (hp.inference.frame.w, hp.inference.frame.h)) - for frame in demo_generator.infer(audio_name="4598", video_name="2145", model_type="original"): - out.write(frame) - out.release() - processed_time.append(time.time() - start) - - command = f"ffmpeg -hide_banner -loglevel error -y -i {'sample/4598_orig.wav'} -i {'temp/original.mp4'} -strict -2 -q:v 1 {'temp/original_voice.mp4'}" - subprocess.call(command, shell=platform.system() != 'Windows') - print(f"Processed time: {np.mean(processed_time)}") - - processed_time = [] - for _i in range(5): - start = time.time() - out = cv2.VideoWriter('temp/compressed.mp4', - cv2.VideoWriter_fourcc(*'mp4v'), - hp.face.video_fps, - (hp.inference.frame.w, hp.inference.frame.h)) - for frame in demo_generator.infer(audio_name="4598", video_name="2145", model_type="compressed"): - out.write(frame) - out.release() - processed_time.append(time.time() - start) - - command = f"ffmpeg -hide_banner -loglevel error -y -i {'sample/4598_orig.wav'} -i {'temp/compressed.mp4'} -strict -2 -q:v 1 {'temp/compressed_voice.mp4'}" - subprocess.call(command, shell=platform.system() != 'Windows') - print(f"Processed time: {np.mean(processed_time)}") - - -if __name__ == '__main__': - main() + return output_video_path, inference_time, inference_fps \ No newline at end of file diff --git a/nota_wav2lip/inference.py b/nota_wav2lip/inference.py index 282f630..4aa1242 100644 --- a/nota_wav2lip/inference.py +++ b/nota_wav2lip/inference.py @@ -12,8 +12,9 @@ class Wav2LipInferenceImpl: - def __init__(self, hp_inference_model: DictConfig, device='cpu'): + def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'): self.model: nn.Module = load_model( + model_name, device=device, **hp_inference_model ) diff --git a/nota_wav2lip/models/util.py b/nota_wav2lip/models/util.py index f51405c..1c3d8c2 100644 --- a/nota_wav2lip/models/util.py +++ b/nota_wav2lip/models/util.py @@ -1,10 +1,13 @@ -import importlib -from typing import Type, Union +from typing import Type, Dict import torch -from nota_wav2lip.models import Wav2LipBase +from nota_wav2lip.models import Wav2LipBase, Wav2Lip, NotaWav2Lip +MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = { + 'wav2lip': Wav2Lip, + 'nota_wav2lip': NotaWav2Lip +} def _load(checkpoint_path, device): assert device in ['cpu', 'cuda'] @@ -14,14 +17,10 @@ def _load(checkpoint_path, device): return torch.load(checkpoint_path) return torch.load(checkpoint_path, map_location=lambda storage, _: storage) -def load_model(cls: Union[str, Type[Wav2LipBase]], checkpoint, device, **kwargs) -> Wav2LipBase: +def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase: - if isinstance(cls, str): - cls_str_splitted = cls.split('.') - cls_parent = '.'.join(cls_str_splitted[:-1]) - cls_module_name = cls_str_splitted[-1] - cls = getattr(importlib.import_module(cls_parent), cls_module_name) - assert issubclass(cls, Wav2LipBase) + cls = MODEL_REGISTRY[model_name.lower()] + assert issubclass(cls, Wav2LipBase) model = cls(**kwargs) checkpoint = _load(checkpoint, device) diff --git a/nota_wav2lip/preprocess/core.py b/nota_wav2lip/preprocess/core.py index f947043..225f41a 100644 --- a/nota_wav2lip/preprocess/core.py +++ b/nota_wav2lip/preprocess/core.py @@ -9,7 +9,7 @@ from loguru import logger import face_detection -from nota_wav2lip.preprocess.ffmpeg import FFMPEG_LOGGING_MODE +from nota_wav2lip.util import FFMPEG_LOGGING_MODE detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cpu') PADDING = [0, 10, 0, 0] diff --git a/nota_wav2lip/preprocess/lrs3_download.py b/nota_wav2lip/preprocess/lrs3_download.py index c38ebf9..3bfcbd7 100644 --- a/nota_wav2lip/preprocess/lrs3_download.py +++ b/nota_wav2lip/preprocess/lrs3_download.py @@ -9,7 +9,7 @@ import numpy as np from loguru import logger -from nota_wav2lip.preprocess.ffmpeg import FFMPEG_LOGGING_MODE +from nota_wav2lip.util import FFMPEG_LOGGING_MODE class LabelInfo(TypedDict): text: str diff --git a/nota_wav2lip/preprocess/ffmpeg.py b/nota_wav2lip/util.py similarity index 100% rename from nota_wav2lip/preprocess/ffmpeg.py rename to nota_wav2lip/util.py diff --git a/nota_wav2lip/video.py b/nota_wav2lip/video.py index 95a5232..5e9c274 100644 --- a/nota_wav2lip/video.py +++ b/nota_wav2lip/video.py @@ -10,7 +10,7 @@ class VideoSlicer: - def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str], video_path: Union[Path, str]): + def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str]): self.fps = hp.face.video_fps self.frame_dir = frame_dir self.frame_path_list = sorted(Path(self.frame_dir).glob("*.jpg")) @@ -21,11 +21,6 @@ def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str], vid self.bbox: List[List[int]] = [metadata['bbox'][key] for key in sorted(metadata['bbox'].keys())] self.bbox_format = metadata['format'] assert len(self.bbox) == len(self.frame_array_list) - self._video_path = video_path - - @property - def video_path(self): - return self._video_path def __len__(self): return len(self.frame_array_list)