Skip to content

Commit 82e37c0

Browse files
committed
a few improvements
1 parent 523c051 commit 82e37c0

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

src/tabpfn/misc/onnx_wrapper.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def export_model(
249249

250250
# Define dynamic axes for variable input sizes
251251
dynamic_axes = {
252-
"X": {0: "num_datapoints", 1: "batch_size", 2: "num_features"},
252+
"X": {0: "num_datapoints", 2: "num_features"},
253253
"y": {0: "num_labels"},
254254
"single_eval_pos": {},
255255
"only_return_standard_out": {},
@@ -291,9 +291,10 @@ def check_input_names(model_path: str) -> None:
291291
model_path: The path to the ONNX model file.
292292
"""
293293
onnx.load(model_path)
294-
# get input names from graph
295-
graph = onnx.load(model_path).graph
296-
[input_node.name for input_node in graph.input]
294+
295+
# Print input names
296+
297+
# Print output names
297298

298299

299300
if __name__ == "__main__":

src/tabpfn/model/encoders.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def normalize_data(
9090
mean = torch_nanmean(data, axis=0) # type: ignore
9191
std = torch_nanstd(data, axis=0) + 1e-20
9292

93-
if len(data) == 1 or normalize_positions == 1:
93+
if data.shape[0] == 1 or normalize_positions == 1:
9494
std[:] = 1.0
9595

9696
if std_only:

0 commit comments

Comments
 (0)