Commit 0c39038

Merge branch 'inference-pipeline' into main

Hyoung-Kyu Song committed Feb 21, 2024
2 parents f398685 + 1eb145b

Showing 11 changed files with 116 additions and 98 deletions.
1 change: 0 additions & 1 deletion app.py
@@ -42,7 +42,6 @@
 for video_name in sorted(video_label_dict):
     video_stem = Path(video_label_dict[video_name])
     servicer.update_video(video_stem, video_stem.with_suffix('.json'),
-                          video_path=video_stem.with_suffix('.mp4'),
                           name=video_name)

 for audio_name in sorted(audio_label_dict):
2 changes: 0 additions & 2 deletions config/nota_wav2lip.yaml
@@ -6,10 +6,8 @@ inference:
     w: 224
   model:
     wav2lip:
-      cls: nota_wav2lip.models.Wav2Lip
       checkpoint: "checkpoints/lrs3-wav2lip.pth"
     nota_wav2lip:
-      cls: nota_wav2lip.models.NotaWav2Lip
       checkpoint: "checkpoints/lrs3-nota-wav2lip.pth"

 audio:
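Note: with the `cls` entries dropped, each entry under `model` carries only its checkpoint path; the class is now resolved from the entry's key via the registry in nota_wav2lip/models/util.py (see that file's diff below). A minimal sketch of how these keys are consumed, mirroring the updated demo.py and assuming the checkpoints referenced in the config have been downloaded:

from config import hparams as hp
from nota_wav2lip.inference import Wav2LipInferenceImpl

# The keys under hp.inference.model ('wav2lip', 'nota_wav2lip') double as
# registry names, so no importlib class path is needed anymore.
for model_name in ('wav2lip', 'nota_wav2lip'):
    impl = Wav2LipInferenceImpl(
        model_name, hp_inference_model=hp.inference.model[model_name], device='cpu'
    )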
83 changes: 83 additions & 0 deletions inference.py
@@ -0,0 +1,83 @@
import os
import subprocess
from pathlib import Path
import argparse

from config import hparams as hp
from nota_wav2lip import Wav2LipModelComparisonDemo


LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)

if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
    subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
    subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)

def parse_args():

    parser = argparse.ArgumentParser(description="NotaWav2Lip: Inference snippet for your own video and audio pair")

    parser.add_argument(
        '-a',
        '--audio-input',
        type=str,
        required=True,
        help="Path of the audio file"
    )

    parser.add_argument(
        '-v',
        '--video-frame-input',
        type=str,
        required=True,
        help="Input directory with face image sequence. We recommend extracting the face image sequence with `preprocess.py`."
    )

    parser.add_argument(
        '-b',
        '--bbox-input',
        type=str,
        help="Path of the file with bbox coordinates. We recommend extracting the json file with `preprocess.py`. "
             "If None, the json file is assumed to be in the same directory as the face images: {VIDEO_FRAME_INPUT}.with_suffix('.json')."
    )

    parser.add_argument(
        '-m',
        '--model',
        choices=['wav2lip', 'nota_wav2lip'],
        default='nota_wav2lip',
        help="Model for generating talking video. Default: nota_wav2lip"
    )

    parser.add_argument(
        '-o',
        '--output-dir',
        type=str,
        default="result",
        help="Output directory to save the result. Default: result"
    )

    parser.add_argument(
        '-d',
        '--device',
        choices=['cpu', 'cuda'],
        default='cpu',
        help="Device setting for model inference. Default: cpu"
    )

    args = parser.parse_args()

    return args

if __name__ == "__main__":
    args = parse_args()
    bbox_input = args.bbox_input if args.bbox_input is not None \
        else Path(args.video_frame_input).with_suffix('.json')

    servicer = Wav2LipModelComparisonDemo(device=args.device, result_dir=args.output_dir, model_list=args.model)
    servicer.update_audio(args.audio_input, name='a0')
    servicer.update_video(args.video_frame_input, bbox_input, name='v0')

    servicer.save_as_video('a0', 'v0', args.model)
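Note: the script pulls missing checkpoints via wget when LRS_ORIGINAL_URL / LRS_COMPRESSED_URL are set, but the actual URLs are not part of this commit. A sketch of invoking the new script with that environment prepared; the URL below is a placeholder:

import os
import subprocess

# Placeholder URL: the real checkpoint location is not included in this diff.
os.environ['LRS_COMPRESSED_URL'] = 'https://example.com/lrs3-nota-wav2lip.pth'

subprocess.run(
    ['python', 'inference.py',
     '-a', 'sample/1673_orig.wav',
     '-v', 'sample_video_lrs3/EV3OmxrowWE-00003',
     '-m', 'nota_wav2lip'],
    check=True,
)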
6 changes: 6 additions & 0 deletions inference.sh
@@ -0,0 +1,6 @@
python inference.py \
    -a "sample/1673_orig.wav" \
    -v "sample_video_lrs3/EV3OmxrowWE-00003" \
    -m "nota_wav2lip" \
    -o "result" \
    --device cpu
89 changes: 13 additions & 76 deletions nota_wav2lip/demo.py
@@ -3,20 +3,23 @@
 import subprocess
 import time
 from pathlib import Path
-from typing import Dict, Iterator, List, Literal
+from typing import Dict, Iterator, List, Literal, Optional, Union

 import cv2
 import numpy as np

 from config import hparams as hp
 from nota_wav2lip.inference import Wav2LipInferenceImpl
 from nota_wav2lip.video import AudioSlicer, VideoSlicer
+from nota_wav2lip.util import FFMPEG_LOGGING_MODE


 class Wav2LipModelComparisonDemo:
-    def __init__(self, device='cpu', result_dir='./temp', model_list: List[str]=None):
+    def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]]=None):
         if model_list is None:
-            model_list = ['wav2lip', 'nota_wav2lip']
+            model_list: List[str] = ['wav2lip', 'nota_wav2lip']
+        if isinstance(model_list, str) and len(model_list) != 0:
+            model_list: List[str] = [model_list]
         super().__init__()
         self.video_dict: Dict[str, VideoSlicer] = {}
         self.audio_dict: Dict[str, AudioSlicer] = {}
@@ -25,7 +28,7 @@ def __init__(self, device='cpu', result_dir='./temp', model_list: List[str]=None
         for model_name in model_list:
             assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}"
             self.model_zoo[model_name] = Wav2LipInferenceImpl(
-                hp_inference_model=hp.inference.model[model_name], device=device
+                model_name, hp_inference_model=hp.inference.model[model_name], device=device
             )

         self._params_zoo: Dict[str, str] = {
@@ -56,16 +59,16 @@ def update_audio(self, audio_path, name=None):
             {_name: AudioSlicer(audio_path)}
         )

-    def update_video(self, frame_dir_path, bbox_path, video_path=None, name=None):
+    def update_video(self, frame_dir_path, bbox_path, name=None):
         _name = name if name is not None else Path(frame_dir_path).stem
         self.video_dict.update(
-            {_name: VideoSlicer(frame_dir_path, bbox_path, video_path=video_path)}
+            {_name: VideoSlicer(frame_dir_path, bbox_path)}
         )

     def save_as_video(self, audio_name, video_name, model_type):

-        output_video_path = self.result_dir / 'original_voice.mp4'
-        frame_only_video_path = self.result_dir / 'original.mp4'
+        output_video_path = self.result_dir / 'generated_with_audio.mp4'
+        frame_only_video_path = self.result_dir / 'generated.mp4'
         audio_path = self.audio_dict[audio_name].audio_path

         out = cv2.VideoWriter(str(frame_only_video_path),
@@ -78,77 +81,11 @@ def save_as_video(self, audio_name, video_name, model_type):
         inference_time = time.time() - start
         out.release()

-        command = f"ffmpeg -hide_banner -loglevel error -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}"
+        command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}"
         subprocess.call(command, shell=platform.system() != 'Windows')

         # The number of frames of generated video
         video_frames_num = len(self.audio_dict[audio_name])
         inference_fps = video_frames_num / inference_time

         return output_video_path, inference_time, inference_fps
-
-
-def get_parsed_args():
-
-    parser = argparse.ArgumentParser(description='Inference code to lip-sync videos in the wild using Wav2Lip models')
-
-    parser.add_argument('--resize_factor', default=1, type=int,
-                        help='Reduce the resolution by this factor. Sometimes, best results are obtained at 480p or 720p')
-
-    parser.add_argument('--crop', nargs='+', type=int, default=[0, -1, 0, -1],
-                        help='Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. '
-                             'Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width')
-
-    parser.add_argument('--box', nargs='+', type=int, default=[-1, -1, -1, -1],
-                        help='Specify a constant bounding box for the face. Use only as a last resort if the face is not detected. '
-                             'Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).')
-
-    args = parser.parse_args()
-
-    return args
-
-
-def main():
-    get_parsed_args()
-
-    demo_generator = Wav2LipModelComparisonDemo()
-    demo_generator.update_audio("sample/1673_orig.wav", name="1673")
-    demo_generator.update_audio("sample/4598_orig.wav", name="4598")
-    demo_generator.update_video("sample/2145_orig", "sample/2145_orig.json", name="2145")
-    demo_generator.update_video("sample/2942_orig", "sample/2942_orig.json", name="2942")
-
-    processed_time = []
-    for _i in range(5):
-        start = time.time()
-        out = cv2.VideoWriter('temp/original.mp4',
-                              cv2.VideoWriter_fourcc(*'mp4v'),
-                              hp.face.video_fps,
-                              (hp.inference.frame.w, hp.inference.frame.h))
-        for frame in demo_generator.infer(audio_name="4598", video_name="2145", model_type="original"):
-            out.write(frame)
-        out.release()
-        processed_time.append(time.time() - start)
-
-    command = f"ffmpeg -hide_banner -loglevel error -y -i {'sample/4598_orig.wav'} -i {'temp/original.mp4'} -strict -2 -q:v 1 {'temp/original_voice.mp4'}"
-    subprocess.call(command, shell=platform.system() != 'Windows')
-    print(f"Processed time: {np.mean(processed_time)}")
-
-    processed_time = []
-    for _i in range(5):
-        start = time.time()
-        out = cv2.VideoWriter('temp/compressed.mp4',
-                              cv2.VideoWriter_fourcc(*'mp4v'),
-                              hp.face.video_fps,
-                              (hp.inference.frame.w, hp.inference.frame.h))
-        for frame in demo_generator.infer(audio_name="4598", video_name="2145", model_type="compressed"):
-            out.write(frame)
-        out.release()
-        processed_time.append(time.time() - start)
-
-    command = f"ffmpeg -hide_banner -loglevel error -y -i {'sample/4598_orig.wav'} -i {'temp/compressed.mp4'} -strict -2 -q:v 1 {'temp/compressed_voice.mp4'}"
-    subprocess.call(command, shell=platform.system() != 'Windows')
-    print(f"Processed time: {np.mean(processed_time)}")
-
-
-if __name__ == '__main__':
-    main()
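Note: taken together with the new inference.py above, the slimmed-down demo class is driven like this (a minimal sketch using the sample paths from inference.sh; it assumes the frame directory and bbox json already exist):

from nota_wav2lip import Wav2LipModelComparisonDemo

# model_list now also accepts a bare string naming a single model.
servicer = Wav2LipModelComparisonDemo(device='cpu', result_dir='result', model_list='nota_wav2lip')
servicer.update_audio('sample/1673_orig.wav', name='a0')
servicer.update_video('sample_video_lrs3/EV3OmxrowWE-00003',
                      'sample_video_lrs3/EV3OmxrowWE-00003.json', name='v0')

# save_as_video returns the output path plus timing stats.
output_path, inference_time, inference_fps = servicer.save_as_video('a0', 'v0', 'nota_wav2lip')
print(f"{output_path} generated in {inference_time:.2f}s ({inference_fps:.1f} FPS)")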
3 changes: 2 additions & 1 deletion nota_wav2lip/inference.py
@@ -12,8 +12,9 @@


 class Wav2LipInferenceImpl:
-    def __init__(self, hp_inference_model: DictConfig, device='cpu'):
+    def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
         self.model: nn.Module = load_model(
+            model_name,
             device=device,
             **hp_inference_model
         )
19 changes: 9 additions & 10 deletions nota_wav2lip/models/util.py
@@ -1,10 +1,13 @@
-import importlib
-from typing import Type, Union
+from typing import Type, Dict

 import torch

-from nota_wav2lip.models import Wav2LipBase
+from nota_wav2lip.models import Wav2LipBase, Wav2Lip, NotaWav2Lip

+MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = {
+    'wav2lip': Wav2Lip,
+    'nota_wav2lip': NotaWav2Lip
+}

 def _load(checkpoint_path, device):
     assert device in ['cpu', 'cuda']
@@ -14,14 +17,10 @@ def _load(checkpoint_path, device):
         return torch.load(checkpoint_path)
     return torch.load(checkpoint_path, map_location=lambda storage, _: storage)

-def load_model(cls: Union[str, Type[Wav2LipBase]], checkpoint, device, **kwargs) -> Wav2LipBase:
+def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase:

-    if isinstance(cls, str):
-        cls_str_splitted = cls.split('.')
-        cls_parent = '.'.join(cls_str_splitted[:-1])
-        cls_module_name = cls_str_splitted[-1]
-        cls = getattr(importlib.import_module(cls_parent), cls_module_name)
-    assert issubclass(cls, Wav2LipBase)
+    cls = MODEL_REGISTRY[model_name.lower()]
+    assert issubclass(cls, Wav2LipBase)

     model = cls(**kwargs)
     checkpoint = _load(checkpoint, device)
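Note: the registry turns model selection into a plain dictionary lookup. A short sketch of the new entry point, with the checkpoint path taken from config/nota_wav2lip.yaml (it assumes the file has been downloaded):

from nota_wav2lip.models.util import load_model

# 'nota_wav2lip' is looked up in MODEL_REGISTRY; any extra kwargs are
# forwarded to the model constructor.
model = load_model('nota_wav2lip',
                   device='cpu',
                   checkpoint='checkpoints/lrs3-nota-wav2lip.pth')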
2 changes: 1 addition & 1 deletion nota_wav2lip/preprocess/core.py
@@ -9,7 +9,7 @@
 from loguru import logger

 import face_detection
-from nota_wav2lip.preprocess.ffmpeg import FFMPEG_LOGGING_MODE
+from nota_wav2lip.util import FFMPEG_LOGGING_MODE

 detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cpu')
 PADDING = [0, 10, 0, 0]
2 changes: 1 addition & 1 deletion nota_wav2lip/preprocess/lrs3_download.py
@@ -9,7 +9,7 @@
 import numpy as np
 from loguru import logger

-from nota_wav2lip.preprocess.ffmpeg import FFMPEG_LOGGING_MODE
+from nota_wav2lip.util import FFMPEG_LOGGING_MODE

 class LabelInfo(TypedDict):
     text: str
File renamed without changes: nota_wav2lip/preprocess/ffmpeg.py → nota_wav2lip/util.py, per the import updates above.
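Note: the renamed module is the source of the FFMPEG_LOGGING_MODE mapping imported above. Its contents are not shown in this diff; a sketch consistent with the literal flags the new demo.py line replaces (the 'ERROR' entry is grounded in the old command string, any further keys would be assumptions):

# nota_wav2lip/util.py (sketch): ffmpeg verbosity presets keyed by mode name.
FFMPEG_LOGGING_MODE = {
    'ERROR': '-hide_banner -loglevel error',
}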
7 changes: 1 addition & 6 deletions nota_wav2lip/video.py
@@ -10,7 +10,7 @@


 class VideoSlicer:
-    def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str], video_path: Union[Path, str]):
+    def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str]):
         self.fps = hp.face.video_fps
         self.frame_dir = frame_dir
         self.frame_path_list = sorted(Path(self.frame_dir).glob("*.jpg"))
@@ -21,11 +21,6 @@ def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str], video_path: Union[Path, str]):
         self.bbox: List[List[int]] = [metadata['bbox'][key] for key in sorted(metadata['bbox'].keys())]
         self.bbox_format = metadata['format']
         assert len(self.bbox) == len(self.frame_array_list)
-        self._video_path = video_path
-
-    @property
-    def video_path(self):
-        return self._video_path

     def __len__(self):
         return len(self.frame_array_list)
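Note: with video_path gone, VideoSlicer is built from the frame directory and the bbox json alone. A minimal usage sketch with the sample paths used elsewhere in this commit:

from nota_wav2lip.video import VideoSlicer

slicer = VideoSlicer('sample_video_lrs3/EV3OmxrowWE-00003',
                     'sample_video_lrs3/EV3OmxrowWE-00003.json')
print(len(slicer), slicer.fps)  # number of frames and the configured FPS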
