
[BUG] No training possible with an explicit Validation Set provided #141

@alexisdurand

Describe the bug
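Calling `fit()` with an explicit validation set (`X_val`, `y_val`) raises a `ValueError` instead of training on the provided split. The error comes from the check `if X_val:` in `SklearnBaseClassifier.fit`, which is ambiguous when `X_val` is a NumPy array.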

To Reproduce
```python
from mambular.models import MambaTabClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic binary classification data: 1000 samples, 10 features
X, y = make_classification(n_samples=1000, n_features=10, n_classes=2)
print(X.shape)

# Hold out 25% of the data as an explicit validation set
X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.25)
print(X_tr.shape, X_val.shape)

model = MambaTabClassifier()
# Passing NumPy arrays as X_val/y_val raises ValueError inside fit()
model.fit(X_tr, y_tr, X_val=X_val, y_val=y_val)
```

Expected behavior
The model should train on exactly the training and validation sets I provide, without splitting the provided training data again.
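A minimal sketch of the expected behaviour, for illustration only: `resolve_train_val` is a hypothetical helper, not part of mambular's API, and its `val_size=0.2` / `random_state=None` defaults are assumptions rather than the library's actual defaults. It only mirrors the idea implied by the `fit()` signature in the traceback (which accepts `val_size`, `X_val` and `y_val`): use an explicit validation set unchanged when given, and fall back to a random split otherwise.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def resolve_train_val(X, y, X_val=None, y_val=None, val_size=0.2, random_state=None):
    """Hypothetical helper returning (X_train, X_val, y_train, y_val).

    If an explicit validation set is given, it is used unchanged and all of
    X/y is kept for training; otherwise X/y is split randomly by val_size.
    """
    # Require both halves of the validation set, or neither
    if (X_val is None) != (y_val is None):
        raise ValueError("X_val and y_val must be provided together")
    # Identity checks avoid calling bool() on an array or DataFrame
    if X_val is not None:
        return X, X_val, y, y_val
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=val_size, random_state=random_state
    )
    return X_train, X_val, y_train, y_val


# Usage: an explicit 75/25 split is passed through untouched
rng = np.random.default_rng(0)
X = rng.random((100, 3))
y = rng.integers(0, 2, size=100)
X_tr, X_v, y_tr, y_v = resolve_train_val(X[:75], y[:75], X[75:], y[75:])
print(X_tr.shape, X_v.shape)  # (75, 3) (25, 3)

# Without X_val/y_val, the fallback split uses val_size
X_tr2, X_v2, _, _ = resolve_train_val(X, y, val_size=0.2, random_state=0)
print(X_tr2.shape, X_v2.shape)  # (80, 3) (20, 3)
```

The explicit-set branch never touches the training data, which is the behaviour this report expects from `fit()`.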

Screenshots
```
ValueError                                Traceback (most recent call last)
Cell In[84], line 12
      9 print( X_tr.shape, X_val.shape )
     11 model = MambaTabClassifier()
---> 12 model.fit(X_tr,y_tr, X_val=X_val, y_val=y_val)

File /opt/conda/lib/python3.10/site-packages/mambular/models/sklearn_base_classifier.py:324, in SklearnBaseClassifier.fit(self, X, y, val_size, X_val, y_val, max_epochs, random_state, batch_size, shuffle, patience, monitor, mode, lr, lr_patience, factor, weight_decay, checkpoint_path, dataloader_kwargs, rebuild, **trainer_kwargs)
    322 if isinstance(y, pd.Series):
    323     y = y.values
--> 324 if X_val:
    325     if not isinstance(X_val, pd.DataFrame):
    326         X_val = pd.DataFrame(X_val)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
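The traceback points at `if X_val:`. The snippet below is a small, self-contained illustration of why that line fails, plus a suggested fix (`if X_val is not None:`); this is a suggestion, not the maintainers' patch. `truthiness_error` is a hypothetical helper written only for this demonstration. Note that converting `X_val` to a DataFrame before the check would not help, because pandas also refuses to evaluate a DataFrame's truth value.

```python
import numpy as np
import pandas as pd


def truthiness_error(obj):
    """Hypothetical helper: return the ValueError message raised by `if obj:`, or None."""
    try:
        if obj:
            pass
    except ValueError as exc:
        return str(exc)
    return None


X_val = np.zeros((250, 10))

# Current check in fit(): `if X_val:` fails for any multi-element array...
print("ndarray:  ", truthiness_error(X_val))
# ...and a DataFrame raises the same kind of error
print("DataFrame:", truthiness_error(pd.DataFrame(X_val)))

# Suggested check: an identity comparison never calls bool() on the data
for candidate in (X_val, pd.DataFrame(X_val), None):
    print(type(candidate).__name__, candidate is not None)
```

An `is not None` test works uniformly for NumPy arrays, DataFrames and the default `None`, so it would also cover the other sklearn-wrapper models mentioned below if they share the same check.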

Desktop
Kaggle notebook (Python 3.10, per the `site-packages` path in the traceback)

Additional context
The same error occurs with the other mambular models that use the sklearn wrapper.

Labels: bug (Something isn't working)