Skip to content

Commit 05ab7da

Browse files
authored
Update README.md
1 parent d84cd1a commit 05ab7da

File tree

1 file changed

+43
-46
lines changed

1 file changed

+43
-46
lines changed

README.md

+43-46
Original file line numberDiff line numberDiff line change
@@ -43,51 +43,6 @@ git clone https://github.com/PriorLabs/TabPFN.git
4343
pip install -e "TabPFN[dev]"
4444
```
4545

46-
### Offline Usage
47-
48-
TabPFN automatically downloads model weights when first used. For offline usage:
49-
50-
#### Manual Download
51-
52-
1. Download the model files manually from HuggingFace:
53-
- Classifier: [tabpfn-v2-classifier.ckpt](https://huggingface.co/Prior-Labs/TabPFN-v2-clf/resolve/main/tabpfn-v2-classifier.ckpt)
54-
- Regressor: [tabpfn-v2-regressor.ckpt](https://huggingface.co/Prior-Labs/TabPFN-v2-reg/resolve/main/tabpfn-v2-regressor.ckpt)
55-
56-
2. Place the file in one of these locations:
57-
- Specify directly: `TabPFNClassifier(model_path="/path/to/model.ckpt")`
58-
- Set environment variable: `os.environ["TABPFN_MODEL_CACHE_DIR"] = "/path/to/dir"`
59-
- Default OS cache directory:
60-
- Windows: `%APPDATA%\tabpfn\`
61-
- macOS: `~/Library/Caches/tabpfn/`
62-
- Linux: `~/.cache/tabpfn/`
63-
64-
#### Quick Download Script
65-
66-
```python
67-
import requests
68-
from tabpfn.utils import _user_cache_dir
69-
import sys
70-
71-
# Get default cache directory using TabPFN's internal function
72-
cache_dir = _user_cache_dir(platform=sys.platform)
73-
cache_dir.mkdir(parents=True, exist_ok=True)
74-
75-
# Define models to download
76-
models = {
77-
"tabpfn-v2-classifier.ckpt": "https://huggingface.co/Prior-Labs/TabPFN-v2-clf/resolve/main/tabpfn-v2-classifier.ckpt",
78-
"tabpfn-v2-regressor.ckpt": "https://huggingface.co/Prior-Labs/TabPFN-v2-reg/resolve/main/tabpfn-v2-regressor.ckpt",
79-
}
80-
81-
# Download each model
82-
for name, url in models.items():
83-
path = cache_dir / name
84-
print(f"Downloading {name} to {path}")
85-
with open(path, "wb") as f:
86-
f.write(requests.get(url).content)
87-
88-
print(f"Models downloaded to {cache_dir}")
89-
```
90-
9146
### Basic Usage
9247

9348
#### Classification
@@ -231,7 +186,49 @@ A: TabPFN v2 requires **Python 3.9+** due to newer language features. Compatible
231186
### **Installation & Setup**
232187

233188
**Q: How do I use TabPFN without an internet connection?**
234-
A: Manually download the model weights from [Hugging Face](https://huggingface.co/Prior-Labs/) and place them in your cache directory (see [Offline Usage](#offline-usage)).
189+
190+
TabPFN automatically downloads model weights when first used. For offline usage:
191+
192+
**Manual Download**
193+
194+
1. Download the model files manually from HuggingFace:
195+
- Classifier: [tabpfn-v2-classifier.ckpt](https://huggingface.co/Prior-Labs/TabPFN-v2-clf/resolve/main/tabpfn-v2-classifier.ckpt)
196+
- Regressor: [tabpfn-v2-regressor.ckpt](https://huggingface.co/Prior-Labs/TabPFN-v2-reg/resolve/main/tabpfn-v2-regressor.ckpt)
197+
198+
2. Place the file in one of these locations:
199+
- Specify directly: `TabPFNClassifier(model_path="/path/to/model.ckpt")`
200+
- Set environment variable: `os.environ["TABPFN_MODEL_CACHE_DIR"] = "/path/to/dir"`
201+
- Default OS cache directory:
202+
- Windows: `%APPDATA%\tabpfn\`
203+
- macOS: `~/Library/Caches/tabpfn/`
204+
- Linux: `~/.cache/tabpfn/`
205+
206+
**Quick Download Script**
207+
208+
```python
209+
import requests
210+
from tabpfn.utils import _user_cache_dir
211+
import sys
212+
213+
# Get default cache directory using TabPFN's internal function
214+
cache_dir = _user_cache_dir(platform=sys.platform)
215+
cache_dir.mkdir(parents=True, exist_ok=True)
216+
217+
# Define models to download
218+
models = {
219+
"tabpfn-v2-classifier.ckpt": "https://huggingface.co/Prior-Labs/TabPFN-v2-clf/resolve/main/tabpfn-v2-classifier.ckpt",
220+
"tabpfn-v2-regressor.ckpt": "https://huggingface.co/Prior-Labs/TabPFN-v2-reg/resolve/main/tabpfn-v2-regressor.ckpt",
221+
}
222+
223+
# Download each model
224+
for name, url in models.items():
225+
path = cache_dir / name
226+
print(f"Downloading {name} to {path}")
227+
with open(path, "wb") as f:
228+
f.write(requests.get(url).content)
229+
230+
print(f"Models downloaded to {cache_dir}")
231+
```
235232

236233
**Q: I'm getting a `pickle` error when loading the model. What should I do?**
237234
A: Try the following:

0 commit comments

Comments
 (0)