forked from resemble-ai/resemble-enhance
-
Notifications
You must be signed in to change notification settings - Fork 0
/
deepspeed_installer.py
65 lines (48 loc) · 2.4 KB
/
deepspeed_installer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import subprocess
import sys
import importlib.metadata as metadata # Use importlib.metadata
from packaging import version
from loguru import logger
def install_package(package_link):
    """Install *package_link* with pip into the running interpreter's environment.

    *package_link* is handed to ``pip install`` as a single argument, so it
    may be a requirement specifier (e.g. ``'deepspeed==0.11.2'``) or a wheel
    URL.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    # Run pip through the current interpreter so the package lands in the
    # same environment that is executing this script.
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command.append(package_link)
    subprocess.check_call(pip_command)
def is_package_installed(package_name):
    """Return True if a distribution called *package_name* is installed.

    The name is looked up with ``importlib.metadata.version``; a
    ``PackageNotFoundError`` from that lookup means "not installed".
    """
    try:
        metadata.version(package_name)
    except metadata.PackageNotFoundError:
        found = False
    else:
        found = True
    return found
def check_and_install_torch(
    requirements=("torch==2.1.1+cu118", "torchaudio==2.1.1+cu118"),
    index_url="https://download.pytorch.org/whl/cu118",
):
    """Install pinned torch/torchaudio builds unless they are already present.

    Every entry of *requirements* must be an exact ``name==version`` pin.
    Each named distribution's installed version is compared with its pin.
    If any is missing or differs, all *requirements* are installed together
    with pip from *index_url*. Otherwise an "already installed" message is
    logged.

    Args:
        requirements: Exact ``name==version`` pins to enforce. Defaults to
            the torch/torchaudio 2.1.1 ``+cu118`` builds.
        index_url: Package index pip resolves the pins from.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    # Human-readable form used only in log messages.
    required_torch_version = " ".join(requirements)

    def _is_satisfied(requirement):
        # True when the named distribution is installed at exactly the pinned version.
        name, _, wanted = requirement.partition("==")
        try:
            return metadata.version(name) == wanted
        except metadata.PackageNotFoundError:
            return False

    if all(_is_satisfied(req) for req in requirements):
        logger.info(f"'{required_torch_version}' already installed.")
        return

    logger.info(f"'{required_torch_version}' not found. Installing...")
    # Each requirement, the option, and its value must be separate argv
    # elements: pip cannot parse a space-joined string as one requirement,
    # nor "--index-url <url>" as a single token.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", *requirements, "--index-url", index_url]
    )
def install_deepspeed_based_on_python_version():
    """Install DeepSpeed 0.11.2 if it is not already installed.

    On Windows, a prebuilt ``+cuda118`` ``win_amd64`` wheel matching the
    interpreter's CPython version (3.10 or 3.11) is installed. Other Python
    versions on Windows are logged as unsupported and nothing is installed.
    On every other platform ``deepspeed==0.11.2`` is installed from the
    default package index.

    torch is not checked or installed by this function.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    if is_package_installed('deepspeed'):
        logger.info("'deepspeed' already installed.")
        return

    python_version = sys.version_info
    logger.info(f"Python version: {python_version.major}.{python_version.minor}")

    if sys.platform == 'win32':
        # Prebuilt Windows wheels, keyed by (major, minor) CPython version.
        windows_wheels = {
            (3, 10): "https://github.com/daswer123/xtts-webui/releases/download/deepspeed/deepspeed-0.11.2+cuda118-cp310-cp310-win_amd64.whl",
            (3, 11): "https://github.com/daswer123/xtts-webui/releases/download/deepspeed/deepspeed-0.11.2+cuda118-cp311-cp311-win_amd64.whl",
        }
        deepspeed_link = windows_wheels.get((python_version.major, python_version.minor))
        if deepspeed_link is None:
            logger.error("Unsupported Python version on Windows.")
            return
    else:
        # Non-Windows platforms (assumed Linux/macOS): let pip resolve the pinned release.
        deepspeed_link = 'deepspeed==0.11.2'

    logger.info("Installing DeepSpeed...")
    install_package(deepspeed_link)
if __name__ == "__main__":
    # Script entry point: run the DeepSpeed installer when this file is executed directly.
    install_deepspeed_based_on_python_version()