-
Notifications
You must be signed in to change notification settings - Fork 79
/
Copy pathloader.py
165 lines (139 loc) · 6.06 KB
/
loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import os
import logging
import torch
import asyncio
import aiohttp
import requests
from huggingface_hub import hf_hub_download
# Configure logging: the root logger is set to DEBUG with timestamped lines,
# and this module logs through its own named logger.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Configuration: everything lives under $DATA_ROOT (default /tmp/data);
# downloaded weights go to its "models" subdirectory.
DATA_ROOT = os.environ.get('DATA_ROOT', '/tmp/data')
MODELS_DIR = os.path.join(DATA_ROOT, "models")

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face repository that hosts every file listed in MODEL_FILES.
HF_REPO_ID = "jbilcke-hf/model-cocktail"

# Repository-relative paths of the files to fetch. Each path doubles as the
# file's location under MODELS_DIR (see download_hf_file). Order matters only
# for readability; entries are grouped by top-level repository folder.
MODEL_FILES = [
    "dwpose/dw-ll_ucoco_384.pth",
    "face-detector/s3fd-619a316812.pth",
    # LivePortrait (human mode) weights.
    *(
        f"liveportrait/{weight}"
        for weight in (
            "spade_generator.pth",
            "warping_module.pth",
            "motion_extractor.pth",
            "stitching_retargeting_module.pth",
            "appearance_feature_extractor.pth",
            "landmark.onnx",
        )
    ),
    # Animal mode is intentionally not downloaded: it reportedly does not
    # support stitching yet, see
    # https://github.com/KwaiVGI/LivePortrait/blob/main/assets/docs/changelog/2024-08-02.md#updates-on-animals-mode
    # Its files would be liveportrait-animals/{warping_module,spade_generator,
    # motion_extractor,appearance_feature_extractor,
    # stitching_retargeting_module,xpose}.pth
    #
    # HACK: the insightface files are pre-fetched here instead of fixing
    # liveportrait/utils/dependencies/insightface/utils/storage.py.
    "insightface/models/buffalo_l.zip",
    *(f"insightface/buffalo_l/{onnx}" for onnx in ("det_10g.onnx", "2d106det.onnx")),
    # sd-vae-ft-mse weights (both serializations) and its config.
    *(
        f"sd-vae-ft-mse/{artifact}"
        for artifact in (
            "diffusion_pytorch_model.bin",
            "diffusion_pytorch_model.safetensors",
            "config.json",
        )
    ),
    # Not used yet: flux-dev/flux-dev-fp8.safetensors,
    # flux-dev/flux_dev_quantization_map.json,
    # pulid-flux/pulid_flux_v0.9.0.safetensors, pulid-flux/pulid_v1.bin
]
def create_directory(directory):
    """Ensure ``directory`` exists, creating it and any missing parents.

    Logs whether the directory was newly created or was already present.

    Args:
        directory: Path of the directory to ensure.

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
        OSError: any other failure from ``os.makedirs`` (e.g. permissions).
    """
    try:
        # Try to create directly rather than checking first: an
        # exists()-then-makedirs() sequence races with any other process
        # creating the same path, and the loser would crash.
        os.makedirs(directory)
    except FileExistsError:
        # A regular file (or other non-directory) at this path is an error;
        # reporting it as an existing directory would only defer the failure
        # to the first write into it.
        if not os.path.isdir(directory):
            raise
        logger.info(f" Directory already exists: {directory}")
    else:
        logger.info(f" Directory created: {directory}")
def print_directory_structure(startpath):
    """Log the directory tree rooted at ``startpath``, one entry per line.

    Directories are logged as ``name/`` and files as ``name``. Each nesting
    level is indented by four spaces. Output goes to ``logger.info``.

    Args:
        startpath: Root of the tree to log. A trailing path separator is
            accepted and does not affect the output.
    """
    # Normalize first so a trailing separator neither blanks the root's name
    # nor shifts the depth computation below.
    top = os.path.normpath(startpath)
    for root, _dirs, files in os.walk(top):
        # Depth comes from the path *relative to* the root. The previous
        # str.replace() approach removed every occurrence of startpath and
        # miscounted depth when startpath ended with a separator.
        rel = os.path.relpath(root, top)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = " " * 4 * level
        logger.info(f"{indent}{os.path.basename(root)}/")
        file_indent = " " * 4 * (level + 1)
        for name in files:
            logger.info(f"{file_indent}{name}")
async def download_hf_file(filename: str) -> None:
    """Download one file from the HF_REPO_ID repository into MODELS_DIR.

    The file is expected to land at ``MODELS_DIR/<filename>`` (hf_hub_download
    is called with ``local_dir=MODELS_DIR``). If that path already exists, the
    download is skipped.

    Args:
        filename: Path of the file inside the repository, e.g.
            ``"liveportrait/landmark.onnx"``.

    Raises:
        Exception: whatever ``hf_hub_download`` raises. It is logged and
            re-raised after removing any file left at the destination.
    """
    dest = os.path.join(MODELS_DIR, filename)
    os.makedirs(os.path.dirname(dest), exist_ok=True)
    if os.path.exists(dest):
        # this is really for debugging purposes only
        logger.debug(f" ✅ {filename}")
        return

    logger.info(f" ⏳ Downloading {HF_REPO_ID}/{filename}")
    try:
        # hf_hub_download blocks; run it in a worker thread so that the
        # concurrent downloads started by download_all_models can overlap.
        await asyncio.to_thread(
            hf_hub_download,
            repo_id=HF_REPO_ID,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as e:
        # Name the file: several downloads run at once, so an anonymous error
        # message is hard to attribute.
        logger.error(f"🚨 Error downloading {filename} from Hugging Face: {e}")
        # Remove anything left at dest; the existence check above would
        # otherwise treat it as a finished download on the next run.
        if os.path.exists(dest):
            os.remove(dest)
        raise
    logger.info(f" ✅ Downloaded {filename}")
async def download_all_models(*, show_tree: bool = False):
    """Ensure every entry of MODEL_FILES is present under MODELS_DIR.

    All downloads run concurrently and are awaited to completion, even when
    some of them fail. A plain ``gather`` would propagate the first error
    while the remaining downloads kept running unobserved. After all of them
    finish, the first failure in MODEL_FILES order is re-raised.

    Args:
        show_tree: When True, log the directory tree of MODELS_DIR once all
            models are available. This is useful when debugging downloads.

    Raises:
        Exception: The first exception raised by ``download_hf_file``, if any.
    """
    logger.info(" 🔎 Looking for models...")
    results = await asyncio.gather(
        *(download_hf_file(filename) for filename in MODEL_FILES),
        return_exceptions=True,
    )
    failures = [result for result in results if isinstance(result, BaseException)]
    if failures:
        # Each failure was already logged individually by download_hf_file.
        logger.error(f"🚨 {len(failures)} of {len(MODEL_FILES)} model downloads failed")
        raise failures[0]
    logger.info(" ✅ All models are available")
    if show_tree:
        logger.info("💡 Printing directory structure of models:")
        print_directory_structure(MODELS_DIR)
class ModelLoader:
    """Builds the model pipelines the app needs (currently only LivePortrait)."""

    def __init__(self):
        # Snapshot of the module configuration: the download root for the
        # weights and the torch device selected at import time.
        self.models_dir = MODELS_DIR
        self.device = DEVICE

    async def load_live_portrait(self):
        """Construct and return a ``LivePortraitPipeline``.

        The liveportrait imports happen here, at call time, rather than at
        module import. The pipeline constructor runs in a worker thread via
        ``asyncio.to_thread``, so the event loop is not blocked while it
        initializes. Presumably that is where the weights get loaded; this
        is not visible from here.
        """
        from liveportrait.config.inference_config import InferenceConfig
        from liveportrait.config.crop_config import CropConfig
        from liveportrait.live_portrait_pipeline import LivePortraitPipeline

        logger.info(" ⏳ Loading LivePortrait models...")
        # The original author notes these flags as the default values.
        inference_settings = InferenceConfig(
            flag_stitching=True,  # we recommend setting it to True!
            flag_relative=True,  # whether to use relative motion
            flag_pasteback=True,  # paste the animated face crop back into the original image space
            flag_do_crop=True,  # whether to crop the source portrait to the face-cropping space
            flag_do_rot=True,  # whether to apply rotation when flag_do_crop is True
        )
        crop_settings = CropConfig()
        pipeline = await asyncio.to_thread(
            LivePortraitPipeline,
            inference_cfg=inference_settings,
            crop_cfg=crop_settings,
        )
        logger.info(" ✅ LivePortrait models loaded successfully.")
        return pipeline
async def initialize_models():
    """Fetch any missing model files, then load and return the LivePortrait pipeline."""
    logger.info("🚀 Starting model initialization...")
    # Weights must be on disk before the pipeline can be constructed.
    await download_all_models()
    pipeline = await ModelLoader().load_live_portrait()
    logger.info("✅ Model initialization completed.")
    return pipeline
# Initial setup: executed at import time so that MODELS_DIR exists before any
# download writes into it.
logger.info("🚀 Setting up storage directories...")
for _storage_dir in (MODELS_DIR,):
    create_directory(_storage_dir)
del _storage_dir
logger.info("✅ Storage directories setup completed.")