diff --git a/tests/nn_tests.py b/tests/nn_tests.py index b9d657cd..fcb6dd97 100644 --- a/tests/nn_tests.py +++ b/tests/nn_tests.py @@ -59,17 +59,6 @@ def test_nn_conf_dict(): clf = TPOTClassifier(config_dict=classifier_config_nn) assert clf.config_dict == classifier_config_nn -def test_nn_errors_on_multiclass(): - """Assert that TPOT-NN throws an error when you try to pass training data with > 2 classes. (NN)""" - clf = TPOTClassifier( - random_state=42, - population_size=1, - generations=1, - config_dict=classifier_config_nn, - template='PytorchLRClassifier' - ) - assert_raises(ValueError, clf.fit, multiclass_X, multiclass_y) - def test_pytorch_lr_classifier(): """Assert that the PytorchLRClassifier model works. (NN)""" clf = nn.PytorchLRClassifier(