Skip to content

Commit b0bc461

Browse files
authored
Merge pull request #184 from mert-kurttutan/main
Implement get_total_memory for windows
2 parents b953e87 + e2b402d commit b0bc461

9 files changed

+162
-53
lines changed

.github/workflows/pull_request.yml

+9-47
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,18 @@ jobs:
3333
# https://github.com/actions/setup-python/issues/855
3434
python-version: "3.9"
3535
dependency-set: minimum
36+
- os: windows-latest
37+
python-version: "3.9"
38+
dependency-set: minimum
3639
- os: ubuntu-latest
3740
python-version: "3.12"
3841
dependency-set: maximum
3942
- os: macos-latest
4043
python-version: "3.12"
4144
dependency-set: maximum
45+
- os: windows-latest
46+
python-version: "3.12"
47+
dependency-set: maximum
4248
runs-on: ${{ matrix.os }}
4349

4450
steps:
@@ -58,64 +64,20 @@ jobs:
5864
- name: Generate requirements file for minimum dependencies
5965
if: matrix.dependency-set == 'minimum'
6066
run: |
61-
python << EOF
62-
import re
63-
64-
with open('pyproject.toml', 'r') as f:
65-
content = f.read()
66-
67-
# Find dependencies section using regex
68-
deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
69-
if deps_match:
70-
deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()]
71-
min_reqs = []
72-
for dep in deps:
73-
match = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s"\']+)', dep)
74-
if match:
75-
package, min_ver = match.groups()
76-
min_reqs.append(f"{package}=={min_ver}")
77-
78-
with open('requirements.txt', 'w') as f:
79-
f.write('\n'.join(min_reqs))
80-
EOF
67+
python scripts/get_min_dependencies.py
8168
8269
- name: Generate requirements file for maximum dependencies
8370
if: matrix.dependency-set == 'maximum'
8471
run: |
85-
python << EOF
86-
import re
87-
88-
with open('pyproject.toml', 'r') as f:
89-
content = f.read()
90-
91-
# Find dependencies section using regex
92-
deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
93-
if deps_match:
94-
deps = [d.strip(' "\'') for d in deps_match.group(1).strip().split('\n') if d.strip()]
95-
max_reqs = []
96-
for dep in deps:
97-
# Check for maximum version constraint
98-
max_version_match = re.search(r'([^>=<\s]+).*?<\s*([^,\s"\']+)', dep)
99-
if max_version_match:
100-
# If there's a max version, use the version just below it
101-
package, max_ver = max_version_match.groups()
102-
max_reqs.append(f"{package}<{max_ver}")
103-
else:
104-
# If no max version, just use the package name
105-
package = re.match(r'([^>=<\s]+)', dep).group(1)
106-
max_reqs.append(package)
107-
108-
with open('requirements.txt', 'w') as f:
109-
f.write('\n'.join(max_reqs))
110-
EOF
72+
python scripts/get_max_dependencies.py
11173
11274
- name: Install dependencies
11375
run: |
11476
uv pip install --system --no-deps .
11577
# onnx is required for onnx export tests
11678
# we don't install all dev dependencies here for speed
11779
uv pip install --system -r requirements.txt
118-
uv pip install --system pytest onnx
80+
uv pip install --system pytest onnx psutil
11981
12082
- name: Initialize submodules
12183
run: git submodule update --init --recursive

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ dev = [
6262
# Test
6363
"pytest",
6464
"onnx", # required for onnx export tests
65+
"psutil", # required for testing internal memory tool on windows
6566
# Docs
6667
"mkdocs",
6768
"mkdocs-material",

scripts/get_max_dependencies.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import re
2+
3+
def main() -> None:
    """Write ``requirements.txt`` with the loosest upper pin of each dependency.

    Reads ``pyproject.toml`` from the current working directory, locates the
    first ``dependencies = [...]`` array with a regex and emits one
    requirement per line of that array:

    * ``name<version`` when the entry carries an upper bound (``<``);
    * the bare ``name`` otherwise, leaving the version to the installer.

    Nothing is written when no ``dependencies`` array is found.

    NOTE: the array is matched non-greedily up to the first ``]`` and parsed
    one entry per line, so entries containing ``]`` (e.g. extras such as
    ``pkg[extra]``) or several entries on one line are not supported.
    """
    # TOML documents are UTF-8; pass the encoding explicitly instead of relying
    # on the platform's locale encoding.
    with open('pyproject.toml', encoding='utf-8') as f:
        content = f.read()

    # Find dependencies section using regex
    deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
    if not deps_match:
        return

    max_reqs = []
    for line in deps_match.group(1).strip().split('\n'):
        # Drop surrounding whitespace and quotes; a trailing `",` survives this
        # strip, which is why the name pattern below excludes quotes and commas.
        dep = line.strip(' "\'')
        if not dep or dep.startswith('#'):
            continue

        # Check for maximum version constraint
        max_version_match = re.search(
            r'([^>=<\s,"\']+).*?<\s*([^,\s"\']+)', dep
        )
        if max_version_match:
            # Keep the upper bound exactly as declared (exclusive `<`).
            package, max_ver = max_version_match.groups()
            max_reqs.append(f"{package}<{max_ver}")
            continue

        # If no max version, just use the package name
        name_match = re.match(r'([^>=<\s,"\']+)', dep)
        if name_match:
            max_reqs.append(name_match.group(1))

    with open('requirements.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(max_reqs))
26+
27+
if __name__ == '__main__':
28+
main()

scripts/get_min_dependencies.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import re
2+
3+
def main() -> None:
    """Write ``requirements.txt`` pinning each dependency to its minimum version.

    Reads ``pyproject.toml`` from the current working directory, locates the
    first ``dependencies = [...]`` array with a regex and, for every line of
    the form ``name>=version``, emits ``name==version``.

    Entries without a ``>=`` lower bound are skipped. Nothing is written when
    no ``dependencies`` array is found.

    NOTE: the array is matched non-greedily up to the first ``]`` and parsed
    one entry per line, so entries containing ``]`` (e.g. extras such as
    ``pkg[extra]``) or several entries on one line are not supported.
    """
    # TOML documents are UTF-8; pass the encoding explicitly instead of relying
    # on the platform's locale encoding.
    with open('pyproject.toml', encoding='utf-8') as f:
        content = f.read()

    # Find dependencies section using regex
    deps_match = re.search(r'dependencies\s*=\s*\[(.*?)\]', content, re.DOTALL)
    if not deps_match:
        return

    min_reqs = []
    for line in deps_match.group(1).strip().split('\n'):
        if not line.strip():
            continue
        # Drop surrounding whitespace and quotes before matching.
        dep = line.strip(' "\'')
        match = re.match(r'([^>=<\s]+)\s*>=\s*([^,\s"\']+)', dep)
        if match:
            package, min_ver = match.groups()
            min_reqs.append(f"{package}=={min_ver}")

    with open('requirements.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(min_reqs))
20+
21+
if __name__ == '__main__':
22+
main()

src/tabpfn/model/memory.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,23 @@ def get_max_free_memory(
252252
os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES") / 1e9
253253
)
254254
except AttributeError:
255-
# TODO: `os.sysconf` does not exist on windows.
256-
free_memory = cls.convert_units(
257-
default_gb_cpu_if_failed_to_calculate,
258-
"gb",
259-
"b",
260-
)
255+
from tabpfn.utils import get_total_memory_windows
256+
257+
if os.name == "nt":
258+
free_memory = get_total_memory_windows()
259+
else:
260+
warnings.warn(
261+
"Could not get system memory, defaulting to"
262+
f" {default_gb_cpu_if_failed_to_calculate} GB",
263+
RuntimeWarning,
264+
stacklevel=2,
265+
)
266+
free_memory = cls.convert_units(
267+
default_gb_cpu_if_failed_to_calculate,
268+
"gb",
269+
"b",
270+
)
271+
261272
except ValueError:
262273
warnings.warn(
263274
"Could not get system memory, defaulting to"

src/tabpfn/utils.py

+38
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44

55
from __future__ import annotations
66

7+
import ctypes
78
import os
89
import sys
10+
import typing
911
import warnings
1012
from collections.abc import Sequence
1113
from pathlib import Path
@@ -829,3 +831,39 @@ def _transform_borders_one(
829831
)
830832

831833
return logit_cancel_mask, descending_borders, borders_t
834+
835+
836+
# Terminology: use "memory" to refer to physical memory and "swap" for swap memory
837+
def get_total_memory_windows() -> float:
    """Get the total physical memory of the system on Windows via the Win32 API.

    Calls ``GlobalMemoryStatusEx`` from ``kernel32.dll`` through ``ctypes``.

    Returns:
        The total physical memory of the system in GB (bytes divided by 1e9).

    Raises:
        OSError: If ``GlobalMemoryStatusEx`` reports failure.
        AttributeError: If called on a non-Windows platform, where
            ``ctypes.windll`` does not exist.
    """

    # ref: https://github.com/microsoft/windows-rs/blob/c9177f7a65c764c237a9aebbd3803de683bedaab/crates/tests/bindgen/src/fn_return_void_sys.rs#L12
    # ref: https://learn.microsoft.com/en-us/windows/win32/api/sysinfoapi/ns-sysinfoapi-memorystatusex
    # This class mirrors the MEMORYSTATUSEX structure that GlobalMemoryStatusEx
    # fills in; field order and types follow the Microsoft docs linked above.
    class _MEMORYSTATUSEX(ctypes.Structure):
        _fields_: typing.ClassVar = [
            ("dwLength", ctypes.c_ulong),
            ("dwMemoryLoad", ctypes.c_ulong),
            ("ullTotalPhys", ctypes.c_ulonglong),
            ("ullAvailPhys", ctypes.c_ulonglong),
            ("ullTotalPageFile", ctypes.c_ulonglong),
            ("ullAvailPageFile", ctypes.c_ulonglong),
            ("ullTotalVirtual", ctypes.c_ulonglong),
            ("ullAvailVirtual", ctypes.c_ulonglong),
            ("ullAvailExtendedVirtual", ctypes.c_ulonglong),
        ]

    # Initialize the structure
    mem_status = _MEMORYSTATUSEX()
    # The caller must set dwLength to the structure size before the call,
    # see the Microsoft docs above.
    mem_status.dwLength = ctypes.sizeof(_MEMORYSTATUSEX)

    k32_lib = ctypes.windll.LoadLibrary("kernel32.dll")
    # GlobalMemoryStatusEx returns zero on failure; without this check a failed
    # call would leave the structure zeroed and silently report 0 GB.
    if not k32_lib.GlobalMemoryStatusEx(ctypes.byref(mem_status)):
        raise ctypes.WinError()

    return mem_status.ullTotalPhys / 1e9  # Convert bytes to GB

tests/test_classifier_interface.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import io
4+
import os
45
from itertools import product
56
from typing import Callable, Literal
67

@@ -248,6 +249,8 @@ def forward(
248249

249250
@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
250251
def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
252+
if os.name == "nt":
253+
pytest.skip("onnx export is not tested on windows")
251254
X, y = X_y
252255
with torch.no_grad():
253256
classifier = TabPFNClassifier(n_estimators=1, device="cpu", random_state=42)

tests/test_regressor_interface.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import io
4+
import os
45
from itertools import product
56
from typing import Callable, Literal
67

@@ -246,6 +247,8 @@ def forward(
246247
# WARNING: unstable for scipy<1.11.0
247248
@pytest.mark.filterwarnings("ignore::torch.jit.TracerWarning")
248249
def test_onnx_exportable_cpu(X_y: tuple[np.ndarray, np.ndarray]) -> None:
250+
if os.name == "nt":
251+
pytest.skip("onnx export is not tested on windows")
249252
X, y = X_y
250253
with torch.no_grad():
251254
regressor = TabPFNRegressor(n_estimators=1, device="cpu", random_state=43)

tests/test_utils.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Use get_total_memory_windows and compare it against the result from psutil.
2+
# Run these checks only on Windows, i.e. when os.name == "nt".
3+
from __future__ import annotations
4+
5+
import os
6+
7+
8+
def test_internal_windows_total_memory():
    """Check the Win32-based total memory against psutil (Windows only).

    On any other platform the test body is a no-op.
    """
    if os.name != "nt":
        return

    # Imported lazily: these are only needed on Windows.
    import psutil

    from tabpfn.utils import get_total_memory_windows

    reported_gb = get_total_memory_windows()
    expected_gb = psutil.virtual_memory().total / 1e9
    assert reported_gb == expected_gb
17+
18+
19+
def test_internal_windows_total_memory_multithreaded():
    """Check that ten concurrent calls all agree with psutil (Windows only).

    On any other platform the test body is a no-op.
    """
    if os.name != "nt":
        return

    # Imported lazily: these are only needed on Windows.
    import threading

    import psutil

    from tabpfn.utils import get_total_memory_windows

    # Each worker thread appends its own reading here.
    readings = []

    def record_reading():
        readings.append(get_total_memory_windows())

    workers = [threading.Thread(target=record_reading) for _ in range(10)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    expected_gb = psutil.virtual_memory().total / 1e9
    assert all(reading == expected_gb for reading in readings)

0 commit comments

Comments
 (0)