From 3fdc2fc19cea791e3cea56bea5d9662e65404306 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 10 Jan 2026 08:27:21 +0000 Subject: [PATCH 1/5] Initial plan From 4f665af70711aaadaacdcc4d16f89b3567aa280e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 10 Jan 2026 08:44:34 +0000 Subject: [PATCH 2/5] Add multi-target regression support - Modified validation to accept 2D y arrays (n_samples, n_targets) - Added multi-target detection in generic_task.validate_data - Filtered unsupported estimators (only XGBoost, CatBoost support multi-target) - Configured CatBoost with MultiRMSE objective for multi-target - Fixed AutoML.predict to not flatten multi-target predictions - Updated AutoML.fit docstring to document multi-target support Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --- flaml/automl/automl.py | 14 ++++-- flaml/automl/model.py | 12 ++++++ flaml/automl/task/generic_task.py | 61 ++++++++++++++++++++++++--- flaml/automl/task/task.py | 2 + flaml/automl/task/time_series_task.py | 2 +- 5 files changed, 79 insertions(+), 12 deletions(-) diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py index 5cf2d71727..f2c77fec22 100644 --- a/flaml/automl/automl.py +++ b/flaml/automl/automl.py @@ -634,8 +634,11 @@ def predict( X = self._state.task.preprocess(X, self._transformer) y_pred = estimator.predict(X, **pred_kwargs) - if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1 and isinstance(y_pred, np.ndarray): - y_pred = y_pred.flatten() + # Only flatten if not multi-target regression + if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1: + is_multi_target = getattr(self._state, 'is_multi_target', False) + if not is_multi_target: + y_pred = y_pred.flatten() if self._label_transformer: return self._label_transformer.inverse_transform(Series(y_pred.astype(int))) else: @@ -1272,7 +1275,9 @@ def fit( must be the timestamp column (datetime type). Other columns in the dataframe are assumed to be exogenous variables (categorical or numeric). When using ray, X_train can be a ray.ObjectRef. - y_train: A numpy array or a pandas series of labels in shape (n, ). + y_train: A numpy array, pandas series, or pandas dataframe of labels in shape (n, ) + for single-target tasks or (n, k) for multi-target regression tasks. + For multi-target regression, only XGBoost and CatBoost estimators are supported. dataframe: A dataframe of training data including label column. For time series forecast tasks, dataframe must be specified and must have at least two columns, timestamp and label, where the first @@ -1883,7 +1888,8 @@ def is_to_reverse_metric(metric, task): self._state.error_metric = error_metric is_spark_dataframe = isinstance(X_train, psDataFrame) or isinstance(dataframe, psDataFrame) - estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe) + is_multi_target = getattr(self._state, 'is_multi_target', False) + estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe, is_multi_target) if is_spark_dataframe and self._use_spark: # For spark dataframe, use_spark must be False because spark models are trained in parallel themselves diff --git a/flaml/automl/model.py b/flaml/automl/model.py index 99fd9c0c61..7cb6d85c55 100644 --- a/flaml/automl/model.py +++ b/flaml/automl/model.py @@ -2081,6 +2081,18 @@ def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs): cat_features = list(X_train.select_dtypes(include="category").columns) else: cat_features = [] + + # Detect multi-target regression and set appropriate loss function + is_multi_target = False + if self._task.is_regression(): + if isinstance(y_train, np.ndarray) and y_train.ndim == 2 and y_train.shape[1] > 1: + is_multi_target = True + elif isinstance(y_train, DataFrame) and y_train.shape[1] > 1: + is_multi_target = True + + if is_multi_target and "loss_function" not in self.params: + self.params["loss_function"] = "MultiRMSE" + use_best_model = kwargs.get("use_best_model", True) n = max(int(len(y_train) * 0.9), len(y_train) - 1000) if use_best_model else len(y_train) X_tr, y_tr = X_train[:n], y_train[:n] diff --git a/flaml/automl/task/generic_task.py b/flaml/automl/task/generic_task.py index 5b74a3d755..6292dcb6d3 100644 --- a/flaml/automl/task/generic_task.py +++ b/flaml/automl/task/generic_task.py @@ -119,13 +119,15 @@ def validate_data( "a Scipy sparse matrix or a pyspark.pandas dataframe." ) assert isinstance( - y_train_all, (np.ndarray, pd.Series, psSeries) - ), "y_train_all must be a numpy array, a pandas series or a pyspark.pandas series." + y_train_all, (np.ndarray, pd.Series, pd.DataFrame, psSeries) + ), "y_train_all must be a numpy array, a pandas series, a pandas dataframe or a pyspark.pandas series." assert X_train_all.size != 0 and y_train_all.size != 0, "Input data must not be empty." if isinstance(X_train_all, np.ndarray) and len(X_train_all.shape) == 1: X_train_all = np.reshape(X_train_all, (X_train_all.size, 1)) if isinstance(y_train_all, np.ndarray): - y_train_all = y_train_all.flatten() + # Only flatten if it's truly 1D (not multi-target) + if y_train_all.ndim == 1 or (y_train_all.ndim == 2 and y_train_all.shape[1] == 1): + y_train_all = y_train_all.flatten() assert X_train_all.shape[0] == y_train_all.shape[0], "# rows in X_train must match length of y_train." if isinstance(X_train_all, psDataFrame): X_train_all = X_train_all.spark.cache() # cache data to improve compute speed @@ -219,6 +221,20 @@ def validate_data( automl._X_train_all.columns.to_list() if hasattr(automl._X_train_all, "columns") else None ) + # Detect multi-target regression + is_multi_target = False + n_targets = 1 + if self.is_regression(): + if isinstance(automl._y_train_all, np.ndarray) and automl._y_train_all.ndim == 2: + is_multi_target = True + n_targets = automl._y_train_all.shape[1] + elif isinstance(automl._y_train_all, pd.DataFrame): + is_multi_target = True + n_targets = automl._y_train_all.shape[1] + + state.is_multi_target = is_multi_target + state.n_targets = n_targets + automl._sample_weight_full = state.fit_kwargs.get( "sample_weight" ) # NOTE: _validate_data is before kwargs is updated to fit_kwargs_by_estimator @@ -227,14 +243,16 @@ def validate_data( "X_val must be None, a numpy array, a pandas dataframe, " "a Scipy sparse matrix or a pyspark.pandas dataframe." ) - assert isinstance(y_val, (np.ndarray, pd.Series, psSeries)), ( - "y_val must be None, a numpy array, a pandas series " "or a pyspark.pandas series." + assert isinstance(y_val, (np.ndarray, pd.Series, pd.DataFrame, psSeries)), ( + "y_val must be None, a numpy array, a pandas series, a pandas dataframe " "or a pyspark.pandas series." ) assert X_val.size != 0 and y_val.size != 0, ( "Validation data are expected to be nonempty. " "Use None for X_val and y_val if no validation data." ) if isinstance(y_val, np.ndarray): - y_val = y_val.flatten() + # Only flatten if it's truly 1D (not multi-target) + if y_val.ndim == 1 or (y_val.ndim == 2 and y_val.shape[1] == 1): + y_val = y_val.flatten() assert X_val.shape[0] == y_val.shape[0], "# rows in X_val must match length of y_val." if automl._transformer: state.X_val = automl._transformer.transform(X_val) @@ -819,7 +837,7 @@ def evaluate_model_CV( pred_time /= n return val_loss, metric, train_time, pred_time - def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool = False) -> List[str]: + def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool = False, is_multi_target: bool = False) -> List[str]: if "auto" != estimator_list: n_estimators = len(estimator_list) if is_spark_dataframe: @@ -848,6 +866,23 @@ def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: "Non-spark dataframes only support estimator names not ending with `_spark`. Non-supported " "estimators are removed." ) + + # Filter out unsupported estimators for multi-target regression + if is_multi_target and self.is_regression(): + # List of estimators that support multi-target regression natively + multi_target_supported = ["xgboost", "xgb_limitdepth", "catboost"] + original_len = len(estimator_list) + estimator_list = [est for est in estimator_list if est in multi_target_supported] + if len(estimator_list) == 0: + raise ValueError( + "Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. " + "Non-supported estimators are removed. No estimator is left." + ) + elif original_len != len(estimator_list): + logger.warning( + "Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. " + "Non-supported estimators are removed." + ) return estimator_list if self.is_rank(): estimator_list = ["lgbm", "xgboost", "xgb_limitdepth", "lgbm_spark"] @@ -897,6 +932,18 @@ def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: for est in estimator_list if (est.endswith("_spark") if is_spark_dataframe else not est.endswith("_spark")) ] + + # Filter for multi-target regression support + if is_multi_target and self.is_regression(): + # List of estimators that support multi-target regression natively + multi_target_supported = ["xgboost", "xgb_limitdepth", "catboost"] + estimator_list = [est for est in estimator_list if est in multi_target_supported] + if len(estimator_list) == 0: + raise ValueError( + "Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. " + "No supported estimator is available." + ) + return estimator_list def default_metric(self, metric: str) -> str: diff --git a/flaml/automl/task/task.py b/flaml/automl/task/task.py index 540f13fe81..0726ad1be6 100644 --- a/flaml/automl/task/task.py +++ b/flaml/automl/task/task.py @@ -253,6 +253,7 @@ def default_estimator_list( self, estimator_list: Union[List[str], str] = "auto", is_spark_dataframe: bool = False, + is_multi_target: bool = False, ) -> List[str]: """Return the list of default estimators registered for this task type. @@ -262,6 +263,7 @@ def default_estimator_list( Args: estimator_list: Either 'auto' or a list of estimator names to be validated. is_spark_dataframe: True if the data is a spark dataframe. + is_multi_target: True if the task involves multi-target regression. Returns: A list of valid estimator names for this task type. diff --git a/flaml/automl/task/time_series_task.py b/flaml/automl/task/time_series_task.py index 939b10d757..4cbc981ac4 100644 --- a/flaml/automl/task/time_series_task.py +++ b/flaml/automl/task/time_series_task.py @@ -459,7 +459,7 @@ def evaluate_model_CV( pred_time /= n return val_loss, metric, train_time, pred_time - def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool) -> List[str]: + def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool, is_multi_target: bool = False) -> List[str]: assert not is_spark_dataframe, "Spark is not yet supported for time series" # TODO: why not do this if/then in the calling function? From 82de6d34c44a110ab7997ab1332a6ac72d9519e9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 10 Jan 2026 08:48:32 +0000 Subject: [PATCH 3/5] Fix DataFrame label encoding and CV splitting for multi-target - Skip label encoding for multi-target regression DataFrames - Fix CV splitting to use iloc for DataFrame y-values - Add comprehensive multi-target regression tests Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --- flaml/automl/data.py | 13 +- flaml/automl/task/generic_task.py | 5 +- test/automl/test_multi_target.py | 198 ++++++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 test/automl/test_multi_target.py diff --git a/flaml/automl/data.py b/flaml/automl/data.py index 53645a3584..69e075f1e4 100644 --- a/flaml/automl/data.py +++ b/flaml/automl/data.py @@ -373,7 +373,18 @@ def fit_transform(self, X: Union[DataFrame, np.ndarray], y, task: Union[str, "Ta datetime_columns, ) self._drop = drop - if task.is_classification() or not pd.api.types.is_numeric_dtype(y) and not task.is_nlg(): + + # Check if y is multi-target (DataFrame or 2D array with multiple targets) + is_multi_target = False + if isinstance(y, DataFrame) and y.shape[1] > 1: + is_multi_target = True + elif isinstance(y, np.ndarray) and y.ndim == 2 and y.shape[1] > 1: + is_multi_target = True + + # Skip label encoding for multi-target regression + if is_multi_target and task.is_regression(): + self.label_transformer = None + elif task.is_classification() or not pd.api.types.is_numeric_dtype(y) and not task.is_nlg(): if not task.is_token_classification(): from sklearn.preprocessing import LabelEncoder diff --git a/flaml/automl/task/generic_task.py b/flaml/automl/task/generic_task.py index 6292dcb6d3..81db76f4e3 100644 --- a/flaml/automl/task/generic_task.py +++ b/flaml/automl/task/generic_task.py @@ -788,7 +788,10 @@ def evaluate_model_CV( else: X_train, X_val = X_train_split[train_index], X_train_split[val_index] if not is_spark_dataframe: - y_train, y_val = y_train_split[train_index], y_train_split[val_index] + if isinstance(y_train_split, (pd.DataFrame, pd.Series)): + y_train, y_val = y_train_split.iloc[train_index], y_train_split.iloc[val_index] + else: + y_train, y_val = y_train_split[train_index], y_train_split[val_index] if weight is not None: fit_kwargs["sample_weight"] = ( weight[train_index] if isinstance(weight, np.ndarray) else weight.iloc[train_index] diff --git a/test/automl/test_multi_target.py b/test/automl/test_multi_target.py new file mode 100644 index 0000000000..884410bb64 --- /dev/null +++ b/test/automl/test_multi_target.py @@ -0,0 +1,198 @@ +"""Tests for multi-target regression support in FLAML AutoML.""" +import unittest + +import numpy as np +import pandas as pd +import pytest +from sklearn.datasets import make_regression +from sklearn.model_selection import train_test_split + +from flaml import AutoML + + +class TestMultiTargetRegression(unittest.TestCase): + """Test multi-target regression functionality.""" + + def setUp(self): + """Create multi-target regression datasets for testing.""" + # Create synthetic multi-target regression data + self.X, self.y = make_regression( + n_samples=200, n_features=10, n_targets=3, random_state=42, noise=0.1 + ) + self.X_train, self.X_test, self.y_train, self.y_test = train_test_split( + self.X, self.y, test_size=0.2, random_state=42 + ) + + def test_multi_target_with_xgboost(self): + """Test multi-target regression with XGBoost.""" + automl = AutoML() + automl.fit( + self.X_train, + self.y_train, + task="regression", + time_budget=5, + estimator_list=["xgboost"], + verbose=0, + ) + + # Check that the model was trained + self.assertIsNotNone(automl.model) + self.assertEqual(automl.best_estimator, "xgboost") + + # Check predictions shape + y_pred = automl.predict(self.X_test) + self.assertEqual(y_pred.shape, self.y_test.shape) + self.assertEqual(y_pred.ndim, 2) + + def test_multi_target_with_catboost(self): + """Test multi-target regression with CatBoost.""" + try: + import catboost # noqa: F401 + except ImportError: + pytest.skip("CatBoost not installed") + + automl = AutoML() + automl.fit( + self.X_train, + self.y_train, + task="regression", + time_budget=5, + estimator_list=["catboost"], + verbose=0, + ) + + # Check that the model was trained + self.assertIsNotNone(automl.model) + self.assertEqual(automl.best_estimator, "catboost") + + # Check predictions shape + y_pred = automl.predict(self.X_test) + self.assertEqual(y_pred.shape, self.y_test.shape) + self.assertEqual(y_pred.ndim, 2) + + def test_unsupported_estimator_filtered_out(self): + """Test that unsupported estimators are filtered for multi-target.""" + automl = AutoML() + with self.assertRaises(ValueError) as context: + automl.fit( + self.X_train, + self.y_train, + task="regression", + time_budget=5, + estimator_list=["lgbm"], + verbose=0, + ) + self.assertIn("Multi-target regression only supports", str(context.exception)) + + def test_auto_estimator_list(self): + """Test that auto estimator list works with multi-target.""" + automl = AutoML() + automl.fit( + self.X_train, + self.y_train, + task="regression", + time_budget=10, + verbose=0, + ) + + # Check that only supported estimators were used + self.assertIn(automl.best_estimator, ["xgboost", "xgb_limitdepth", "catboost"]) + + # Check predictions shape + y_pred = automl.predict(self.X_test) + self.assertEqual(y_pred.shape, self.y_test.shape) + + def test_multi_target_with_validation_set(self): + """Test multi-target regression with explicit validation set.""" + X_train_sub, X_val, y_train_sub, y_val = train_test_split( + self.X_train, self.y_train, test_size=0.2, random_state=42 + ) + + automl = AutoML() + automl.fit( + X_train_sub, + y_train_sub, + X_val=X_val, + y_val=y_val, + task="regression", + time_budget=5, + estimator_list=["xgboost"], + verbose=0, + ) + + # Check that the model was trained + self.assertIsNotNone(automl.model) + + # Check predictions shape + y_pred = automl.predict(self.X_test) + self.assertEqual(y_pred.shape, self.y_test.shape) + + def test_multi_target_with_dataframe(self): + """Test multi-target regression with pandas DataFrame.""" + X_df = pd.DataFrame(self.X_train, columns=[f"feature_{i}" for i in range(self.X_train.shape[1])]) + y_df = pd.DataFrame(self.y_train, columns=[f"target_{i}" for i in range(self.y_train.shape[1])]) + + automl = AutoML() + automl.fit( + X_df, + y_df, + task="regression", + time_budget=5, + estimator_list=["xgboost"], + verbose=0, + ) + + # Check that the model was trained + self.assertIsNotNone(automl.model) + + # Check predictions shape + X_test_df = pd.DataFrame(self.X_test, columns=[f"feature_{i}" for i in range(self.X_test.shape[1])]) + y_pred = automl.predict(X_test_df) + self.assertEqual(y_pred.shape, self.y_test.shape) + + def test_single_target_still_works(self): + """Test that single-target regression still works correctly.""" + X, y = make_regression(n_samples=200, n_features=10, n_targets=1, random_state=42) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + automl = AutoML() + automl.fit( + X_train, + y_train, + task="regression", + time_budget=5, + estimator_list=["lgbm", "xgboost"], + verbose=0, + ) + + # Check that the model was trained + self.assertIsNotNone(automl.model) + + # Check predictions shape (should be 1D or (n, 1)) + y_pred = automl.predict(X_test) + self.assertEqual(len(y_pred), len(y_test)) + + def test_multi_target_cv(self): + """Test multi-target regression with cross-validation.""" + automl = AutoML() + automl.fit( + self.X_train, + self.y_train, + task="regression", + time_budget=10, + eval_method="cv", + n_splits=3, + estimator_list=["xgboost"], + verbose=0, + ) + + # Check that the model was trained + self.assertIsNotNone(automl.model) + + # Check predictions shape + y_pred = automl.predict(self.X_test) + self.assertEqual(y_pred.shape, self.y_test.shape) + + +if __name__ == "__main__": + unittest.main() From 3e760be864e4b43731111a78704232f9563a555e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 10 Jan 2026 08:51:23 +0000 Subject: [PATCH 4/5] Add demo script for multi-target regression - Added comprehensive demo showing all features of multi-target support - Demonstrates XGBoost and CatBoost support - Shows estimator filtering - Includes DataFrame support examples Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --- examples/multi_target_regression_demo.py | 163 +++++++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 examples/multi_target_regression_demo.py diff --git a/examples/multi_target_regression_demo.py b/examples/multi_target_regression_demo.py new file mode 100644 index 0000000000..8e669ceb31 --- /dev/null +++ b/examples/multi_target_regression_demo.py @@ -0,0 +1,163 @@ +""" +Demo script showing multi-target regression support in FLAML AutoML. + +This script demonstrates: +1. Creating a multi-target regression dataset +2. Training an AutoML model with multi-target support +3. Making predictions with multi-target output +4. Comparing with single-target approach using MultiOutputRegressor wrapper +""" + +import numpy as np +import pandas as pd +from sklearn.datasets import make_regression +from sklearn.metrics import mean_squared_error, r2_score +from sklearn.model_selection import train_test_split + +from flaml import AutoML + +# Create synthetic multi-target regression data +print("=" * 60) +print("Creating Multi-Target Regression Dataset") +print("=" * 60) + +X, y = make_regression( + n_samples=500, + n_features=20, + n_targets=3, # 3 target variables + random_state=42, + noise=0.1, +) + +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +print(f"Training set: X_train shape = {X_train.shape}, y_train shape = {y_train.shape}") +print(f"Test set: X_test shape = {X_test.shape}, y_test shape = {y_test.shape}") +print() + +# Train AutoML with multi-target support +print("=" * 60) +print("Training AutoML with Multi-Target Support") +print("=" * 60) + +automl = AutoML() +automl.fit( + X_train, + y_train, + task="regression", + time_budget=30, # 30 seconds + verbose=0, +) + +print(f"Best estimator: {automl.best_estimator}") +print(f"Best loss: {automl.best_loss:.4f}") +print() + +# Make predictions +print("=" * 60) +print("Making Predictions") +print("=" * 60) + +y_pred = automl.predict(X_test) +print(f"Predictions shape: {y_pred.shape}") +print(f"First 3 predictions:\n{y_pred[:3]}") +print() + +# Evaluate performance +print("=" * 60) +print("Performance Metrics") +print("=" * 60) + +# Overall metrics (averaged across all targets) +mse_overall = mean_squared_error(y_test, y_pred) +r2_overall = r2_score(y_test, y_pred) + +print(f"Overall MSE: {mse_overall:.4f}") +print(f"Overall R²: {r2_overall:.4f}") +print() + +# Per-target metrics +print("Per-Target Metrics:") +for i in range(y_test.shape[1]): + mse_i = mean_squared_error(y_test[:, i], y_pred[:, i]) + r2_i = r2_score(y_test[:, i], y_pred[:, i]) + print(f" Target {i}: MSE = {mse_i:.4f}, R² = {r2_i:.4f}") +print() + +# Compare with pandas DataFrame input +print("=" * 60) +print("Testing with Pandas DataFrame") +print("=" * 60) + +X_df = pd.DataFrame(X_train, columns=[f"feature_{i}" for i in range(X_train.shape[1])]) +y_df = pd.DataFrame(y_train, columns=[f"target_{i}" for i in range(y_train.shape[1])]) + +automl_df = AutoML() +automl_df.fit( + X_df, + y_df, + task="regression", + time_budget=30, + verbose=0, +) + +print(f"Best estimator (DataFrame): {automl_df.best_estimator}") +print(f"Best loss (DataFrame): {automl_df.best_loss:.4f}") +print() + +# Demonstrate filtering of unsupported estimators +print("=" * 60) +print("Demonstrating Estimator Filtering") +print("=" * 60) + +print("Attempting to use LightGBM (unsupported for multi-target)...") +try: + automl_lgbm = AutoML() + automl_lgbm.fit( + X_train, + y_train, + task="regression", + time_budget=5, + estimator_list=["lgbm"], # LightGBM doesn't support multi-target + verbose=0, + ) + print("ERROR: LightGBM should not work with multi-target!") +except ValueError as e: + print(f"✓ Expected error: {e}") +print() + +# Compare supported estimators +print("=" * 60) +print("Comparing Supported Estimators") +print("=" * 60) + +for estimator in ["xgboost", "catboost"]: + try: + print(f"\nTesting {estimator}...") + automl_est = AutoML() + automl_est.fit( + X_train[:200], # Use subset for speed + y_train[:200], + task="regression", + time_budget=10, + estimator_list=[estimator], + verbose=0, + ) + y_pred_est = automl_est.predict(X_test) + mse_est = mean_squared_error(y_test, y_pred_est) + print(f" ✓ {estimator}: MSE = {mse_est:.4f}") + except ImportError: + print(f" ⊗ {estimator}: Not installed") + except Exception as e: + print(f" ✗ {estimator}: Error - {e}") + +print() +print("=" * 60) +print("Demo Complete!") +print("=" * 60) +print("\nKey Takeaways:") +print("1. FLAML now supports multi-target regression natively") +print("2. Only XGBoost and CatBoost are supported for multi-target") +print("3. Works with both numpy arrays and pandas DataFrames") +print("4. Predictions maintain the (n_samples, n_targets) shape") +print("5. Sklearn metrics automatically average across targets") From 0990d8fe17eda0fc67415bbcd51a23bc85688a63 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 21 Jan 2026 07:02:58 +0000 Subject: [PATCH 5/5] Fix formatting with pre-commit - Changed single quotes to double quotes for consistency - Removed trailing whitespace - Fixed line spacing Co-authored-by: thinkall <3197038+thinkall@users.noreply.github.com> --- flaml/automl/automl.py | 4 ++-- flaml/automl/data.py | 4 ++-- flaml/automl/model.py | 6 +++--- flaml/automl/task/generic_task.py | 12 +++++++----- flaml/automl/task/time_series_task.py | 4 +++- test/automl/test_multi_target.py | 4 +--- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py index 3d4f8d18db..b374700479 100644 --- a/flaml/automl/automl.py +++ b/flaml/automl/automl.py @@ -829,7 +829,7 @@ def predict( # Only flatten if not multi-target regression if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1: - is_multi_target = getattr(self._state, 'is_multi_target', False) + is_multi_target = getattr(self._state, "is_multi_target", False) if not is_multi_target: y_pred = y_pred.flatten() if self._label_transformer: @@ -2495,7 +2495,7 @@ def is_to_reverse_metric(metric, task): self._state.error_metric = error_metric is_spark_dataframe = isinstance(X_train, psDataFrame) or isinstance(dataframe, psDataFrame) - is_multi_target = getattr(self._state, 'is_multi_target', False) + is_multi_target = getattr(self._state, "is_multi_target", False) estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe, is_multi_target) if is_spark_dataframe and self._use_spark: diff --git a/flaml/automl/data.py b/flaml/automl/data.py index 69e075f1e4..c7e5bf133b 100644 --- a/flaml/automl/data.py +++ b/flaml/automl/data.py @@ -373,14 +373,14 @@ def fit_transform(self, X: Union[DataFrame, np.ndarray], y, task: Union[str, "Ta datetime_columns, ) self._drop = drop - + # Check if y is multi-target (DataFrame or 2D array with multiple targets) is_multi_target = False if isinstance(y, DataFrame) and y.shape[1] > 1: is_multi_target = True elif isinstance(y, np.ndarray) and y.ndim == 2 and y.shape[1] > 1: is_multi_target = True - + # Skip label encoding for multi-target regression if is_multi_target and task.is_regression(): self.label_transformer = None diff --git a/flaml/automl/model.py b/flaml/automl/model.py index 8318c13c52..b082d71e9c 100644 --- a/flaml/automl/model.py +++ b/flaml/automl/model.py @@ -2112,7 +2112,7 @@ def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs): cat_features = list(X_train.select_dtypes(include="category").columns) else: cat_features = [] - + # Detect multi-target regression and set appropriate loss function is_multi_target = False if self._task.is_regression(): @@ -2120,10 +2120,10 @@ def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs): is_multi_target = True elif isinstance(y_train, DataFrame) and y_train.shape[1] > 1: is_multi_target = True - + if is_multi_target and "loss_function" not in self.params: self.params["loss_function"] = "MultiRMSE" - + use_best_model = kwargs.get("use_best_model", True) n = max(int(len(y_train) * 0.9), len(y_train) - 1000) if use_best_model else len(y_train) X_tr, y_tr = X_train[:n], y_train[:n] diff --git a/flaml/automl/task/generic_task.py b/flaml/automl/task/generic_task.py index b36771c950..44f9fb5168 100644 --- a/flaml/automl/task/generic_task.py +++ b/flaml/automl/task/generic_task.py @@ -231,7 +231,7 @@ def validate_data( elif isinstance(automl._y_train_all, pd.DataFrame): is_multi_target = True n_targets = automl._y_train_all.shape[1] - + state.is_multi_target = is_multi_target state.n_targets = n_targets @@ -1287,7 +1287,9 @@ def evaluate_model_CV( pred_time /= n return val_loss, metric, train_time, pred_time - def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool = False, is_multi_target: bool = False) -> List[str]: + def default_estimator_list( + self, estimator_list: List[str], is_spark_dataframe: bool = False, is_multi_target: bool = False + ) -> List[str]: if "auto" != estimator_list: n_estimators = len(estimator_list) if is_spark_dataframe: @@ -1316,7 +1318,7 @@ def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: "Non-spark dataframes only support estimator names not ending with `_spark`. Non-supported " "estimators are removed." ) - + # Filter out unsupported estimators for multi-target regression if is_multi_target and self.is_regression(): # List of estimators that support multi-target regression natively @@ -1382,7 +1384,7 @@ def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: for est in estimator_list if (est.endswith("_spark") if is_spark_dataframe else not est.endswith("_spark")) ] - + # Filter for multi-target regression support if is_multi_target and self.is_regression(): # List of estimators that support multi-target regression natively @@ -1393,7 +1395,7 @@ def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: "Multi-target regression only supports estimators: xgboost, xgb_limitdepth, catboost. " "No supported estimator is available." ) - + return estimator_list def default_metric(self, metric: str) -> str: diff --git a/flaml/automl/task/time_series_task.py b/flaml/automl/task/time_series_task.py index 7755cb9b4b..4e103752d7 100644 --- a/flaml/automl/task/time_series_task.py +++ b/flaml/automl/task/time_series_task.py @@ -458,7 +458,9 @@ def evaluate_model_CV( pred_time /= n return val_loss, metric, train_time, pred_time - def default_estimator_list(self, estimator_list: List[str], is_spark_dataframe: bool, is_multi_target: bool = False) -> List[str]: + def default_estimator_list( + self, estimator_list: List[str], is_spark_dataframe: bool, is_multi_target: bool = False + ) -> List[str]: assert not is_spark_dataframe, "Spark is not yet supported for time series" # TODO: why not do this if/then in the calling function? diff --git a/test/automl/test_multi_target.py b/test/automl/test_multi_target.py index 884410bb64..c359199373 100644 --- a/test/automl/test_multi_target.py +++ b/test/automl/test_multi_target.py @@ -16,9 +16,7 @@ class TestMultiTargetRegression(unittest.TestCase): def setUp(self): """Create multi-target regression datasets for testing.""" # Create synthetic multi-target regression data - self.X, self.y = make_regression( - n_samples=200, n_features=10, n_targets=3, random_state=42, noise=0.1 - ) + self.X, self.y = make_regression(n_samples=200, n_features=10, n_targets=3, random_state=42, noise=0.1) self.X_train, self.X_test, self.y_train, self.y_test = train_test_split( self.X, self.y, test_size=0.2, random_state=42 )