Skip to content

Commit

Permalink
Merge pull request pycaret#4029 from kondziolka9ld/fix_parallel_mode_…
Browse files Browse the repository at this point in the history
…no_model_trained

Fix: In parallel mode return empty list of models when no models was trained.
  • Loading branch information
Yard1 authored Aug 1, 2024
2 parents 10dc1a4 + 07f9dc8 commit 85efec6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
7 changes: 6 additions & 1 deletion pycaret/parallel/fugue_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ def compare_models(
as_fugue=True,
as_local=True,
).as_array()
res = pd.concat(cloudpickle.loads(x[0]) for x in outputs)

pd_dataframe_for_models = [cloudpickle.loads(x[0]) for x in outputs]
if all(pd_dataframe_for_model.empty for pd_dataframe_for_model in pd_dataframe_for_models):
return []

res = pd.concat(pd_dataframe_for_models)
res = res.sort_values(sort_col, ascending=asc)
top = res.head(self._params.get("n_select", 1))
instance._display_container.append(res.iloc[:, :-1])
Expand Down
24 changes: 22 additions & 2 deletions tests/test_classification_parallel.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import pycaret.classification as pc
from pycaret.datasets import get_data
from pycaret.parallel import FugueBackend


def _score_dummy(y_true, y_prob, axis=0):
return 0.0


def test_classification_parallel():
from pycaret.parallel import FugueBackend

pc.setup(
data_func=lambda: get_data("juice", verbose=False),
target="Purchase",
Expand Down Expand Up @@ -46,3 +45,24 @@ def test_classification_parallel():

pc.compare_models(n_select=2, sort="DUMMY", parallel=be)
pc.pull()


def test_classification_parallel_returns_empty_models_list_when_no_model_is_trained():
pc.setup(
data_func=lambda: get_data("juice", verbose=False),
target="Purchase",
session_id=0,
n_jobs=1,
verbose=False,
html=False,
)

fconf = {
"fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
"fugue.rpc.flask_server.host": "localhost",
"fugue.rpc.flask_server.port": "3333",
"fugue.rpc.flask_server.timeout": "2 sec",
}

res = pc.compare_models(include=[], parallel=FugueBackend("dask", fconf, display_remote=True, batch_size=3, top_only=False))
assert (len(res) == 0)

0 comments on commit 85efec6

Please sign in to comment.