Skip to content

Commit 523c051

Browse files
committed
only init onnx session once
1 parent 720dc16 commit 523c051

File tree

4 files changed

+43
-31
lines changed

4 files changed

+43
-31
lines changed

src/tabpfn/base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,13 @@ def initialize_tabpfn_model(
114114

115115
def load_onnx_model(
116116
model_path: str | Path,
117+
device: torch.device,
117118
) -> ONNXModelWrapper:
118119
"""Load a TabPFN model in ONNX format.
119120
120121
Args:
121122
model_path: Path to the ONNX model file.
123+
device: The device to run the model on.
122124
123125
Returns:
124126
The loaded ONNX model wrapped in a PyTorch-compatible interface.
@@ -139,7 +141,7 @@ def load_onnx_model(
139141
if not model_path.exists():
140142
raise FileNotFoundError(f"ONNX model not found at: {model_path}")
141143

142-
return ONNXModelWrapper(str(model_path))
144+
return ONNXModelWrapper(str(model_path), device)
143145

144146

145147
def determine_precision(

src/tabpfn/classifier.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,15 @@ def fit(self, X: XType, y: YType) -> Self:
389389
"""
390390
static_seed, rng = infer_random_state(self.random_state)
391391

392+
# Determine device and precision
393+
self.device_ = infer_device_and_type(self.device)
394+
(self.use_autocast_, self.forced_inference_dtype_, byte_size) = (
395+
determine_precision(self.inference_precision, self.device_)
396+
)
397+
392398
# Load the model and config
393399
if self.use_onnx:
394-
self.model_ = load_onnx_model("model_classifier.onnx")
400+
self.model_ = load_onnx_model("model_classifier.onnx", self.device_)
395401
else:
396402
self.model_, self.config_, _ = initialize_tabpfn_model(
397403
model_path=self.model_path,
@@ -400,12 +406,6 @@ def fit(self, X: XType, y: YType) -> Self:
400406
static_seed=static_seed,
401407
)
402408

403-
# Determine device and precision
404-
self.device_ = infer_device_and_type(self.device)
405-
(self.use_autocast_, self.forced_inference_dtype_, byte_size) = (
406-
determine_precision(self.inference_precision, self.device_)
407-
)
408-
409409
# Build the interface_config
410410
self.interface_config_ = ModelInterfaceConfig.from_user_input(
411411
inference_config=self.inference_config,

src/tabpfn/misc/onnx_wrapper.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,21 @@
2121
class ONNXModelWrapper:
2222
"""Wrap ONNX model to match the PyTorch model interface."""
2323

24-
def __init__(self, model_path: str):
24+
def __init__(self, model_path: str, device: torch.device):
2525
"""Initialize the ONNX model wrapper.
2626
2727
Args:
2828
model_path: Path to the ONNX model file.
29+
device: The device to run the model on.
2930
"""
3031
self.model_path = model_path
31-
self.providers = ["CPUExecutionProvider"]
32+
self.device = device
33+
if device.type == "cuda":
34+
self.providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
35+
elif device.type == "cpu":
36+
self.providers = ["CPUExecutionProvider"]
37+
else:
38+
raise ValueError(f"Invalid device: {device}")
3239
self.session = ort.InferenceSession(
3340
model_path,
3441
providers=self.providers,
@@ -46,24 +53,27 @@ def to(
4653
Returns:
4754
self
4855
"""
49-
if device.type == "cuda":
50-
# Check if CUDA is available in ONNX Runtime
51-
cuda_provider = "CUDAExecutionProvider"
52-
if cuda_provider in ort.get_available_providers():
53-
self.providers = [cuda_provider, "CPUExecutionProvider"]
54-
# Reinitialize session with CUDA provider
56+
# Only recreate session if device type has changed
57+
if device.type != self.device.type:
58+
if device.type == "cuda":
59+
# Check if CUDA is available in ONNX Runtime
60+
cuda_provider = "CUDAExecutionProvider"
61+
if cuda_provider in ort.get_available_providers():
62+
self.providers = [cuda_provider, "CPUExecutionProvider"]
63+
# Reinitialize session with CUDA provider
64+
self.session = ort.InferenceSession(
65+
self.model_path,
66+
providers=self.providers,
67+
)
68+
# If CUDA is not available, keep current session
69+
else:
70+
self.providers = ["CPUExecutionProvider"]
5571
self.session = ort.InferenceSession(
5672
self.model_path,
5773
providers=self.providers,
5874
)
59-
else:
60-
pass
61-
else:
62-
self.providers = ["CPUExecutionProvider"]
63-
self.session = ort.InferenceSession(
64-
self.model_path,
65-
providers=self.providers,
66-
)
75+
# Update the device
76+
self.device = device
6777
return self
6878

6979
def type(

src/tabpfn/regressor.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,15 @@ def fit(self, X: XType, y: YType) -> Self:
401401
"""
402402
static_seed, rng = infer_random_state(self.random_state)
403403

404+
# Determine device and precision
405+
self.device_ = infer_device_and_type(self.device)
406+
(self.use_autocast_, self.forced_inference_dtype_, byte_size) = (
407+
determine_precision(self.inference_precision, self.device_)
408+
)
409+
404410
# Load the model and config
405411
if self.use_onnx:
406-
self.model_ = load_onnx_model("model_regressor.onnx")
412+
self.model_ = load_onnx_model("model_regressor.onnx", self.device_)
407413
# Initialize bardist_ for ONNX mode
408414
# TODO: faster way to do this
409415
_, self.config_, self.bardist_ = initialize_tabpfn_model(
@@ -420,12 +426,6 @@ def fit(self, X: XType, y: YType) -> Self:
420426
static_seed=static_seed,
421427
)
422428

423-
# Determine device and precision
424-
self.device_ = infer_device_and_type(self.device)
425-
(self.use_autocast_, self.forced_inference_dtype_, byte_size) = (
426-
determine_precision(self.inference_precision, self.device_)
427-
)
428-
429429
# Build the interface_config
430430
self.interface_config_ = ModelInterfaceConfig.from_user_input(
431431
inference_config=self.inference_config,

0 commit comments

Comments (0)