-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsetup.py
More file actions
187 lines (165 loc) · 5.97 KB
/
setup.py
File metadata and controls
187 lines (165 loc) · 5.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import glob
import os
import platform
import shutil
from pathlib import Path
from setuptools import setup, find_packages
from torch.utils.cpp_extension import CppExtension, BuildExtension
# Try to load .env for local development. python-dotenv is optional: when it
# is not installed we simply rely on the real process environment.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    # Only the import is guarded; an ImportError raised while loading the
    # .env file itself is a real error and must not be silently swallowed.
    # Runs before DAWN_PREFIX is read from os.environ below.
    load_dotenv()
ROOT = Path(__file__).parent.absolute()

# Dawn prefix - required for building.
# An explicit DAWN_PREFIX environment variable takes precedence; otherwise the
# first existing directory among a few conventional install locations is used.
# After this block DAWN_PREFIX is always a pathlib.Path.
DAWN_PREFIX = os.environ.get("DAWN_PREFIX")
if DAWN_PREFIX:
    # Expand "~" ourselves: values loaded from a .env file may not have been
    # expanded by a shell.
    DAWN_PREFIX = Path(DAWN_PREFIX).expanduser()
    # Fail early on a bad explicit prefix instead of later with a confusing
    # missing-header / missing-library error from the compiler or linker.
    if not DAWN_PREFIX.is_dir():
        raise RuntimeError(
            f"DAWN_PREFIX is set to '{DAWN_PREFIX}', which is not an existing directory.\n"
            "Please set DAWN_PREFIX to the Dawn installation directory, or run:\n"
            " ./scripts/build-dawn.sh"
        )
else:
    # Check common locations, in priority order.
    possible_paths = [
        ROOT / "dawn-install",
        ROOT / "dawn" / "install" / "Release",
        Path.home() / "dawn" / "install" / "Release",
    ]
    DAWN_PREFIX = next((p for p in possible_paths if p.is_dir()), None)
    if DAWN_PREFIX is None:
        raise RuntimeError(
            "DAWN_PREFIX environment variable not set and Dawn not found in common locations.\n"
            "Please set DAWN_PREFIX to the Dawn installation directory, or run:\n"
            " ./scripts/build-dawn.sh"
        )
def get_dawn_library_name():
    """Return the file name of Dawn's shared library for the host OS.

    Returns:
        The platform-specific shared-library file name
        (``.so`` on Linux, ``.dylib`` on macOS, ``.dll`` on Windows).

    Raises:
        RuntimeError: if ``platform.system()`` is not Linux, Darwin or Windows.
    """
    names_by_system = {
        "Linux": "libwebgpu_dawn.so",
        "Darwin": "libwebgpu_dawn.dylib",
        "Windows": "webgpu_dawn.dll",
    }
    host = platform.system()
    library_name = names_by_system.get(host)
    if library_name is None:
        raise RuntimeError(f"Unsupported platform: {host}")
    return library_name
def get_dawn_lib_path():
    """Locate Dawn's shared library under ``DAWN_PREFIX``.

    Returns:
        The path of the shared library if one is installed, or ``None`` when
        only a static library is present (static linking: nothing to bundle).

    Raises:
        RuntimeError: if neither a shared nor a static Dawn library is found,
            or (via ``get_dawn_library_name``) on an unsupported platform.
    """
    lib_name = get_dawn_library_name()
    lib_dir = DAWN_PREFIX / "lib"

    # Shared-library candidates, in priority order. On Windows, CMake installs
    # typically place DLLs (RUNTIME artifacts) under bin/ while lib/ holds only
    # the import library (.lib); without checking bin/ that import library
    # would be mistaken for a static archive and the DLL never bundled.
    shared_candidates = [lib_dir / lib_name]
    if platform.system() == "Windows":
        shared_candidates.append(DAWN_PREFIX / "bin" / lib_name)
    for candidate in shared_candidates:
        if candidate.exists():
            return candidate

    # No shared library: fall back to a static one (static linking).
    static_suffix = ".lib" if lib_name.endswith(".dll") else ".a"
    static_path = (lib_dir / lib_name).with_suffix(static_suffix)
    if static_path.exists():
        print(f"Dawn static library found at {static_path}, using static linking")
        return None

    searched = ", ".join(str(p) for p in shared_candidates)
    raise RuntimeError(f"Dawn library not found at {searched} or {static_path}")
class BuildExtWithDawn(BuildExtension):
    """Custom ``build_ext`` that bundles the Dawn shared library with the extension.

    After the normal build, a dynamically linked Dawn library is copied into a
    ``libs/`` directory next to the built extension, and on Linux/macOS the
    extension's run path is pointed at that directory on a best-effort basis.
    Nothing is copied when Dawn is linked statically.
    """

    def run(self):
        """Build the extension, then bundle Dawn's shared library next to it."""
        # Run the normal build first so the extension outputs exist.
        super().run()

        # None means only a static Dawn library was found.
        lib_path = get_dawn_lib_path()
        if lib_path is None:
            print("Using static Dawn library, skipping library copy")
            return

        outputs = self.get_outputs()
        self._copy_dawn_library(lib_path, outputs)
        self._set_rpath_to_libs(outputs)

    def _copy_dawn_library(self, lib_path, outputs):
        """Copy *lib_path* into a ``libs/`` directory beside the first build output."""
        if outputs:
            output_dir = Path(outputs[0]).parent
        else:
            # No outputs reported: fall back to the package dir under build_lib.
            output_dir = Path(self.build_lib) / "torch_webgpu"

        libs_dir = output_dir / "libs"
        # parents=True: the fallback package directory may not exist yet.
        libs_dir.mkdir(parents=True, exist_ok=True)

        dst = libs_dir / lib_path.name
        print(f"Copying {lib_path} -> {dst}")
        shutil.copy2(lib_path, dst)

    @staticmethod
    def _set_rpath_to_libs(outputs):
        """Best-effort: make each built binary search ``libs/`` next to itself.

        The platform link flags already request this rpath; this pass is a
        fallback. Tool failures are ignored (``check=False``), and a missing
        tool is reported and skipped: without that guard ``subprocess.run``
        raises FileNotFoundError even with ``check=False``, crashing the build.
        """
        import subprocess

        system = platform.system()
        if system == "Linux":
            # Note: patchelf --set-rpath replaces the binary's existing run path.
            tool = "patchelf"
            suffixes = (".so",)
            tool_args = ["--set-rpath", "$ORIGIN/libs"]
        elif system == "Darwin":
            tool = "install_name_tool"
            suffixes = (".so", ".dylib")
            tool_args = ["-add_rpath", "@loader_path/libs"]
        else:
            return

        if shutil.which(tool) is None:
            print(f"{tool} not found on PATH, skipping RPATH update")
            return

        for output in outputs:
            if output.endswith(suffixes):
                subprocess.run([tool, *tool_args, output], check=False)
# Per-platform (compile args, link args) for the C++ extension. On Linux and
# macOS the linker embeds a run path pointing at the bundled libs/ directory
# next to the extension; Windows gets MSVC-style flags and no extra link args.
# Unknown platforms get no extra flags.
system = platform.system()
extra_compile_args, extra_link_args = {
    "Linux": (["-std=c++17", "-O2"], ["-Wl,-rpath,$ORIGIN/libs"]),
    "Darwin": (["-std=c++17", "-O2"], ["-Wl,-rpath,@loader_path/libs"]),
    "Windows": (["/std:c++17", "/O2"], []),
}.get(system, ([], []))
setup(
    name="torch-webgpu",
    version="0.0.1",
    description="WebGPU backend for PyTorch",
    # Anchored at ROOT like the other paths here, so this does not depend on
    # the current working directory.
    long_description=(ROOT / "README.md").read_text(encoding="utf-8"),
    long_description_content_type="text/markdown",
    author="Jedrzej Maczan",
    author_email="jedrzejpawel@maczan.pl",
    url="https://github.com/jmaczan/torch-webgpu",
    ext_modules=[
        CppExtension(
            name="torch_webgpu._C",
            # glob returns entries in filesystem order; sort so the compile and
            # link order (and thus the build) is reproducible across machines.
            sources=sorted(glob.glob("csrc/**/*.cpp", recursive=True)),
            include_dirs=[
                str(ROOT / "csrc"),
                str(DAWN_PREFIX / "include"),
            ],
            library_dirs=[
                str(DAWN_PREFIX / "lib"),
            ],
            libraries=[
                "webgpu_dawn",
            ],
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        ),
    ],
    # Custom build_ext that bundles the Dawn shared library into libs/.
    cmdclass={"build_ext": BuildExtWithDawn},
    package_dir={"": "python"},
    packages=find_packages(where="python"),
    # Ship the bundled Dawn library copied into torch_webgpu/libs/.
    package_data={
        "torch_webgpu": ["libs/*"],
    },
    include_package_data=True,
    python_requires=">=3.10",
    install_requires=[
        "torch>=2.0.0",
        "numpy",
    ],
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3.12",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
)