diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_beta/_beta_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_beta/_beta_distribution_loss_pkg.py index 1e146f678..3abf36659 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_beta/_beta_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_beta/_beta_distribution_loss_pkg.py @@ -3,6 +3,7 @@ """ from pytorch_forecasting.data import TorchNormalizer +from pytorch_forecasting.data.encoders import GroupNormalizer from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric @@ -16,8 +17,23 @@ class BetaDistributionLoss_pkg(_BasePtMetric): "distribution_type": "beta", "info:metric_name": "BetaDistributionLoss", "requires:data_type": "beta_distribution_forecast", + "info:pred_type": ["distr"], + "info:y_type": ["numeric"], + "loss_ndim": 2, } + @property + def clip_target(cls): + return True + + @property + def data_loader_kwargs(cls): + return { + "target_normalizer": GroupNormalizer( + groups=["agency", "sku"], transformation="logit" + ) + } + @classmethod def get_cls(cls): from pytorch_forecasting.metrics.distributions import BetaDistributionLoss @@ -30,3 +46,14 @@ def get_encoder(cls): Returns a TorchNormalizer instance for rescaling parameters. """ return TorchNormalizer(transformation="logit") + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for BetaDistributionLoss. + """ + kwargs = dict(target="agency") + kwargs.update(cls.data_loader_kwargs) + return super()._get_test_dataloaders_from( + params, clip_target=cls.clip_target, **kwargs + ) diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_implicit_quantile_network/_implicit_quantile_network_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_implicit_quantile_network/_implicit_quantile_network_distribution_loss_pkg.py index d3589154b..ec8aa4798 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_implicit_quantile_network/_implicit_quantile_network_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_implicit_quantile_network/_implicit_quantile_network_distribution_loss_pkg.py @@ -18,6 +18,8 @@ class ImplicitQuantileNetworkDistributionLoss_pkg(_BasePtMetric): "requires:data_type": "implicit_quantile_network_distribution_forecast", "capability:quantile_generation": True, "shape:adds_quantile_dimension": True, + "info:pred_type": ["distr"], + "info:y_type": ["numeric"], } @classmethod @@ -44,3 +46,10 @@ def get_metric_test_params(cls): fixture for testing the ImplicitQuantileNetworkDistributionLoss metric. """ return [{"input_size": 5}] + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for ImplicitQuantileNetworkDistributionLoss. + """ + return super()._get_test_dataloaders_from(params) diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_log_normal/_log_normal_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_log_normal/_log_normal_distribution_loss_pkg.py index a17d7f862..b48e6b809 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_log_normal/_log_normal_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_log_normal/_log_normal_distribution_loss_pkg.py @@ -5,6 +5,7 @@ import torch from pytorch_forecasting.data import TorchNormalizer +from pytorch_forecasting.data.encoders import GroupNormalizer from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric @@ -18,8 +19,23 @@ class LogNormalDistributionLoss_pkg(_BasePtMetric): "distribution_type": "log_normal", "info:metric_name": "LogNormalDistributionLoss", "requires:data_type": "log_normal_distribution_forecast", + "info:pred_type": ["distr"], + "info:y_type": ["numeric"], + "loss_ndim": 2, } + @property + def clip_target(self): + return True + + @property + def data_loader_kwargs(self): + return { + "target_normalizer": GroupNormalizer( + groups=["agency", "sku"], transformation="log1p" + ) + } + @classmethod def get_cls(cls): from pytorch_forecasting.metrics.distributions import LogNormalDistributionLoss @@ -48,3 +64,14 @@ def prepare_test_inputs(cls, test_case): ) return y_pred, y + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for LogNormalDistributionLoss. + """ + kwargs = dict(target="agency") + kwargs.update(cls.data_loader_kwargs) + return super()._get_test_dataloaders_from( + params, clip_target=cls.clip_target, **kwargs + ) diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_mqf2/_mqf2_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_mqf2/_mqf2_distribution_loss_pkg.py index ed925ccfe..23b921475 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_mqf2/_mqf2_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_mqf2/_mqf2_distribution_loss_pkg.py @@ -3,6 +3,7 @@ """ from pytorch_forecasting.data import TorchNormalizer +from pytorch_forecasting.data.encoders import GroupNormalizer from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric @@ -20,6 +21,22 @@ class MQF2DistributionLoss_pkg(_BasePtMetric): "requires:data_type": "mqf2_distribution_forecast", } + @property + def clip_target(self): + return True + + @property + def data_loader_kwargs(self): + return { + "target_normalizer": GroupNormalizer( + groups=["agency", "sku"], center=False, transformation="log1p" + ) + } + + @property + def trainer_kwargs(self): + return dict(accelerator="cpu") + @classmethod def get_cls(cls): from pytorch_forecasting.metrics.distributions import MQF2DistributionLoss diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_multivariate_normal/_multivariate_normal_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_multivariate_normal/_multivariate_normal_distribution_loss_pkg.py index 9e0db69f6..85625ae92 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_multivariate_normal/_multivariate_normal_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_multivariate_normal/_multivariate_normal_distribution_loss_pkg.py @@ -2,6 +2,7 @@ Package container for multivariate normal distribution loss metric. """ +from pytorch_forecasting.data.encoders import GroupNormalizer from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric @@ -17,8 +18,23 @@ class MultivariateNormalDistributionLoss_pkg(_BasePtMetric): "distribution_type": "multivariate_normal", "info:metric_name": "MultivariateNormalDistributionLoss", "requires:data_type": "multivariate_normal_distribution_forecast", + "info:pred_type": ["distr"], + "info:y_type": ["numeric"], + "loss_ndim": 2, } + @property + def clip_target(self): + return False + + @property + def data_loader_kwargs(self): + return { + "target_normalizer": GroupNormalizer( + groups=["agency", "sku"], transformation="log1p" + ) + } + @classmethod def get_cls(cls): from pytorch_forecasting.metrics.distributions import ( @@ -26,3 +42,12 @@ def get_cls(cls): ) return MultivariateNormalDistributionLoss + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for MultivariateNormalDistributionLoss. + """ + kwargs = dict(target="agency") + kwargs.update(cls.data_loader_kwargs) + return super()._get_test_dataloaders_from(params, **kwargs) diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_negative_binomial/_negative_binomial_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_negative_binomial/_negative_binomial_distribution_loss_pkg.py index 12f2ef65a..e417d7560 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_negative_binomial/_negative_binomial_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_negative_binomial/_negative_binomial_distribution_loss_pkg.py @@ -3,6 +3,7 @@ """ from pytorch_forecasting.data import TorchNormalizer +from pytorch_forecasting.data.encoders import GroupNormalizer from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric @@ -16,8 +17,21 @@ class NegativeBinomialDistributionLoss_pkg(_BasePtMetric): "distribution_type": "negative_binomial", "info:metric_name": "NegativeBinomialDistributionLoss", "requires:data_type": "negative_binomial_distribution_forecast", + "info:pred_type": ["distr"], + "info:y_type": ["numeric"], + "loss_ndim": 2, } + @property + def clip_target(self): + return False + + @property + def data_loader_kwargs(self): + return { + "target_normalizer": GroupNormalizer(groups=["agency", "sku"], center=False) + } + @classmethod def get_cls(cls): from pytorch_forecasting.metrics.distributions import ( @@ -32,3 +46,14 @@ def get_encoder(cls): Returns a TorchNormalizer instance for rescaling parameters. """ return TorchNormalizer(center=False) + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for NegativeBinomialDistributionLoss. + """ + kwargs = dict(target="agency") + kwargs.update(cls.data_loader_kwargs) + return super()._get_test_dataloaders_from( + params, clip_target=cls.clip_target, **kwargs + ) diff --git a/pytorch_forecasting/metrics/_distributions_pkg/_normal/_normal_distribution_loss_pkg.py b/pytorch_forecasting/metrics/_distributions_pkg/_normal/_normal_distribution_loss_pkg.py index 653eddfdd..afff845ac 100644 --- a/pytorch_forecasting/metrics/_distributions_pkg/_normal/_normal_distribution_loss_pkg.py +++ b/pytorch_forecasting/metrics/_distributions_pkg/_normal/_normal_distribution_loss_pkg.py @@ -17,6 +17,9 @@ class NormalDistributionLoss_pkg(_BasePtMetric): "distribution_type": "normal", "info:metric_name": "NormalDistributionLoss", "requires:data_type": "normal_distribution_forecast", + "info:pred_type": ["distr"], + "info:y_type": ["numeric"], + "loss_ndim": 2, } @classmethod @@ -24,3 +27,10 @@ def get_cls(cls): from pytorch_forecasting.metrics.distributions import NormalDistributionLoss return NormalDistributionLoss + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for NormalDistributionLoss. + """ + super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_cross_entropy/_cross_entropy_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_cross_entropy/_cross_entropy_pkg.py index 03ae9e647..60d9f84f2 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_cross_entropy/_cross_entropy_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_cross_entropy/_cross_entropy_pkg.py @@ -18,6 +18,9 @@ class CrossEntropy_pkg(_BasePtMetric): "requires:data_type": "classification_forecast", "info:metric_name": "CrossEntropy", "no_rescaling": True, + "info:pred_type": ["point"], + "info:y_type": ["category"], + "loss_ndim": 1, } @classmethod @@ -25,3 +28,10 @@ def get_cls(cls): from pytorch_forecasting.metrics import CrossEntropy return CrossEntropy + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for CrossEntropy. + """ + super()._get_test_dataloaders_from(params=params, target="category") diff --git a/pytorch_forecasting/metrics/_point_pkg/_mae/_mae_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_mae/_mae_pkg.py index f2db78f80..15632ce60 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_mae/_mae_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_mae/_mae_pkg.py @@ -16,6 +16,9 @@ class MAE_pkg(_BasePtMetric): "metric_type": "point", "requires:data_type": "point_forecast", "info:metric_name": "MAE", + "info:pred_type": ["point"], + "info:y_type": ["numeric"], + "loss_ndim": 1, } @classmethod @@ -23,3 +26,10 @@ def get_cls(cls): from pytorch_forecasting.metrics import MAE return MAE + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for MAE. + """ + return super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_mape/_mape_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_mape/_mape_pkg.py index db9051c75..1db954132 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_mape/_mape_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_mape/_mape_pkg.py @@ -18,6 +18,9 @@ class MAPE_pkg(_BasePtMetric): "metric_type": "point", "info:metric_name": "MAPE", "requires:data_type": "point_forecast", + "info:pred_type": ["point"], + "info:y_type": ["numeric"], + "loss_ndim": 1, } @classmethod @@ -25,3 +28,10 @@ def get_cls(cls): from pytorch_forecasting.metrics.point import MAPE return MAPE + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for MAPE. + """ + return super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_mase/_mase_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_mase/_mase_pkg.py index f128b125d..ea7afe8e6 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_mase/_mase_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_mase/_mase_pkg.py @@ -14,6 +14,9 @@ class MASE_pkg(_BasePtMetric): "metric_type": "point", "info:metric_name": "MASE", "requires:data_type": "point_forecast", + "info:pred_type": ["point"], + "info:y_type": ["numeric"], + "loss_ndim": 1, } @classmethod @@ -21,3 +24,10 @@ def get_cls(cls): from pytorch_forecasting.metrics import MASE return MASE + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for MASE. + """ + return super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_poisson/_poisson_loss_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_poisson/_poisson_loss_pkg.py index 6ac1c3338..366c30ba6 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_poisson/_poisson_loss_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_poisson/_poisson_loss_pkg.py @@ -18,6 +18,9 @@ class PoissonLoss_pkg(_BasePtMetric): "requires:data_type": "point_forecast", "capability:quantile_generation": True, "shape:adds_quantile_dimension": True, + "info:pred_type": ["point"], + "info:y_type": ["numeric"], + "loss_ndim": 1, } @classmethod @@ -25,3 +28,10 @@ def get_cls(cls): from pytorch_forecasting.metrics.point import PoissonLoss return PoissonLoss + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for PoissonLoss. + """ + return super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_rmse/_rmse_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_rmse/_rmse_pkg.py index d9db301eb..a7c7509b9 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_rmse/_rmse_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_rmse/_rmse_pkg.py @@ -16,6 +16,9 @@ class RMSE_pkg(_BasePtMetric): "metric_type": "point", "info:metric_name": "RMSE", "requires:data_type": "point_forecast", + "info:pred_type": ["point"], + "info:y_type": ["numeric"], + "loss_ndim": 1, } # noqa: E501 @classmethod @@ -23,3 +26,10 @@ def get_cls(cls): from pytorch_forecasting.metrics.point import RMSE return RMSE + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for RMSE. + """ + return super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_smape/_smape_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_smape/_smape_pkg.py index 00e0f3d13..b192b74c7 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_smape/_smape_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_smape/_smape_pkg.py @@ -18,6 +18,9 @@ class SMAPE_pkg(_BasePtMetric): "metric_type": "point", "info:metric_name": "SMAPE", "requires:data_type": "point_forecast", + "info:pred_type": ["point"], + "info:y_type": ["numeric"], + "loss_ndim": 1, } # noqa: E501 @classmethod @@ -25,3 +28,10 @@ def get_cls(cls): from pytorch_forecasting.metrics.point import SMAPE return SMAPE + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for SMAPE. + """ + return super()._get_test_dataloaders_from(params=params, target="agency") diff --git a/pytorch_forecasting/metrics/_point_pkg/_tweedie/_tweedie_loss_pkg.py b/pytorch_forecasting/metrics/_point_pkg/_tweedie/_tweedie_loss_pkg.py index 07250ff33..d6ccbed70 100644 --- a/pytorch_forecasting/metrics/_point_pkg/_tweedie/_tweedie_loss_pkg.py +++ b/pytorch_forecasting/metrics/_point_pkg/_tweedie/_tweedie_loss_pkg.py @@ -16,6 +16,9 @@ class TweedieLoss_pkg(_BasePtMetric): "metric_type": "point", "info:metric_name": "TweedieLoss", "requires:data_type": "point_forecast", + "info:pred_type": ["point"], + "info:y_types": ["numeric"], + "loss_ndim": 1, } # noqa: E501 @classmethod @@ -23,3 +26,10 @@ def get_cls(cls): from pytorch_forecasting.metrics.point import TweedieLoss return TweedieLoss + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for TweedieLoss. + """ + return super()._get_test_dataloaders_from(params, target="agency") diff --git a/pytorch_forecasting/metrics/_quantile_pkg/_quantile_loss_pkg.py b/pytorch_forecasting/metrics/_quantile_pkg/_quantile_loss_pkg.py index fd3e27c14..50751daf9 100644 --- a/pytorch_forecasting/metrics/_quantile_pkg/_quantile_loss_pkg.py +++ b/pytorch_forecasting/metrics/_quantile_pkg/_quantile_loss_pkg.py @@ -16,6 +16,9 @@ class QuantileLoss_pkg(_BasePtMetric): "metric_type": "quantile", "info:metric_name": "QuantileLoss", "requires:data_type": "quantile_forecast", + "info:pred_type": ["quantile"], + "info:y_type": ["numeric"], + "loss_ndim": "num_quantiles", } # noqa: E501 @classmethod @@ -34,3 +37,10 @@ def get_metric_test_params(cls): "quantiles": [0.2, 0.5], }, ] + + @classmethod + def _get_test_dataloaders(cls, params=None): + """ + Returns test dataloaders configured for QuantileLoss. + """ + return super()._get_test_dataloaders_from(params, target="agency") diff --git a/pytorch_forecasting/metrics/base_metrics/_base_object.py b/pytorch_forecasting/metrics/base_metrics/_base_object.py index d695ed087..b6447da55 100644 --- a/pytorch_forecasting/metrics/base_metrics/_base_object.py +++ b/pytorch_forecasting/metrics/base_metrics/_base_object.py @@ -1,6 +1,10 @@ """Base object class for pytorch-forecasting metrics.""" from pytorch_forecasting.base._base_object import _BaseObject +from pytorch_forecasting.tests._data_scenarios import ( + data_with_covariates, + make_dataloaders, +) class _BasePtMetric(_BaseObject): @@ -78,3 +82,22 @@ def get_encoder(cls): from pytorch_forecasting.data import TorchNormalizer return TorchNormalizer() + + @classmethod + def _get_test_dataloaders_from(cls, params, **kwargs): + """ + Returns test dataloaders configured for the metric. + Child classes can override or pass kwargs for customization. + """ + if params is None: + params = {} + data_loader_kwargs = {} + data_loader_kwargs.update(params.get("data_loader_kwargs", {})) + data_loader_kwargs.update(kwargs) + clip_target = params.get("clip_target", False) + + data = data_with_covariates() + if clip_target: + data["target"] = data["target"].clip(1e-4, 1 - 1e-4) + dataloaders = make_dataloaders(data, **data_loader_kwargs) + return dataloaders diff --git a/pytorch_forecasting/models/nhits/_nhits_pkg.py b/pytorch_forecasting/models/nhits/_nhits_pkg.py index 359b7f349..710be45a2 100644 --- a/pytorch_forecasting/models/nhits/_nhits_pkg.py +++ b/pytorch_forecasting/models/nhits/_nhits_pkg.py @@ -65,38 +65,31 @@ def _get_test_dataloaders_from(cls, params): Dict of dataloaders created from the parameters. Train, validation, and test dataloaders, in this order. """ - loss = params.get("loss", None) data_loader_kwargs = params.get("data_loader_kwargs", {}) clip_target = params.get("clip_target", False) - import inspect - from pytorch_forecasting.metrics import ( - LogNormalDistributionLoss, MQF2DistributionLoss, - MultivariateNormalDistributionLoss, - NegativeBinomialDistributionLoss, - TweedieLoss, ) from pytorch_forecasting.tests._data_scenarios import ( data_with_covariates, dataloaders_fixed_window_without_covariates, make_dataloaders, ) - from pytorch_forecasting.tests._loss_mapping import DISTR_LOSSES_NUMERIC - - distr_losses = tuple( - type(l) - for l in DISTR_LOSSES_NUMERIC - if not isinstance(l, MultivariateNormalDistributionLoss) - # use dataloaders without covariates as default settings of nhits - # (hidden_size = 512) is not compatible with - # MultivariateNormalDistributionLoss causing Cholesky - # decomposition to fail during loss computation. - ) - if isinstance(loss, distr_losses): + # Use fixed window dataloaders for MultivariateNormalDistributionLoss + if hasattr( + loss, "get_class_tag" + ) and "multivariate_normal" in loss.get_class_tag("distribution_type", ""): + return dataloaders_fixed_window_without_covariates() + + # For other distribution losses, use covariates and apply preprocessing + distr_types = {"log_normal", "negative_binomial", "mqf2", "beta"} + if ( + hasattr(loss, "get_class_tag") + and loss.get_class_tag("distribution_type", "") in distr_types + ): dwc = data_with_covariates() if clip_target: dwc["target"] = dwc["volume"].clip(1e-3, 1.0) @@ -109,19 +102,16 @@ def _get_test_dataloaders_from(cls, params): ) dl_default_kwargs.update(data_loader_kwargs) - if isinstance(loss, NegativeBinomialDistributionLoss): + if loss.get_class_tag("distribution_type", "") == "negative_binomial": dwc = dwc.assign(volume=lambda x: x.volume.round()) - # todo: still need some debugging to add the MQF2DistributionLoss - # elif inspect.isclass(loss) and issubclass(loss, MQF2DistributionLoss): - # dwc = dwc.assign(volume=lambda x: x.volume.round()) - # data_loader_kwargs["target"] = "volume" - # data_loader_kwargs["time_varying_unknown_reals"] = ["volume"] - elif isinstance(loss, LogNormalDistributionLoss): + elif loss.get_class_tag("distribution_type", "") == "log_normal": dwc["volume"] = dwc["volume"].clip(1e-3, 1.0) - dataloaders_with_covariates = make_dataloaders(dwc, **dl_default_kwargs) - return dataloaders_with_covariates + return make_dataloaders(dwc, **dl_default_kwargs) - if isinstance(loss, TweedieLoss): + if ( + hasattr(loss, "get_class_tag") + and loss.get_class_tag("info:metric_name", "") == "TweedieLoss" + ): dwc = data_with_covariates() dl_default_kwargs = dict( target="target", @@ -129,7 +119,6 @@ def _get_test_dataloaders_from(cls, params): add_relative_time_idx=False, ) dl_default_kwargs.update(data_loader_kwargs) - dataloaders_with_covariates = make_dataloaders(dwc, **dl_default_kwargs) - return dataloaders_with_covariates + return make_dataloaders(dwc, **dl_default_kwargs) return dataloaders_fixed_window_without_covariates() diff --git a/pytorch_forecasting/tests/_loss_mapping.py b/pytorch_forecasting/tests/_loss_mapping.py index d2b41fc3e..4664f40b4 100644 --- a/pytorch_forecasting/tests/_loss_mapping.py +++ b/pytorch_forecasting/tests/_loss_mapping.py @@ -1,124 +1,22 @@ -from pytorch_forecasting.data.encoders import GroupNormalizer -from pytorch_forecasting.metrics import ( - MAE, - MAPE, - MASE, - RMSE, - SMAPE, - BetaDistributionLoss, - CrossEntropy, - ImplicitQuantileNetworkDistributionLoss, - LogNormalDistributionLoss, - MQF2DistributionLoss, - MultivariateNormalDistributionLoss, - NegativeBinomialDistributionLoss, - NormalDistributionLoss, - PoissonLoss, - QuantileLoss, - TweedieLoss, -) +from pytorch_forecasting._registry import all_objects -POINT_LOSSES_NUMERIC = [ - MAE(), - RMSE(), - SMAPE(), - MAPE(), - PoissonLoss(), - MASE(), - TweedieLoss(), -] +# Remove legacy lists and mappings for losses by pred/y type and tensor shape checks. +# Use tags and _get_test_dataloaders_from for all compatibility and test setup. -POINT_LOSSES_CATEGORY = [ - CrossEntropy(), -] - -QUANTILE_LOSSES_NUMERIC = [ - QuantileLoss(), -] - -DISTR_LOSSES_NUMERIC = [ - NormalDistributionLoss(), - NegativeBinomialDistributionLoss(), - MultivariateNormalDistributionLoss(), - LogNormalDistributionLoss(), - BetaDistributionLoss(), - ImplicitQuantileNetworkDistributionLoss(), - # todo: still need some debugging to add the MQF2DistributionLoss -] - -LOSSES_BY_PRED_AND_Y_TYPE = { - ("point", "numeric"): POINT_LOSSES_NUMERIC, - ("point", "category"): POINT_LOSSES_CATEGORY, - ("quantile", "numeric"): QUANTILE_LOSSES_NUMERIC, - ("quantile", "category"): [], - ("distr", "numeric"): DISTR_LOSSES_NUMERIC, - ("distr", "category"): [], -} - - -LOSS_SPECIFIC_PARAMS = { - "BetaDistributionLoss": { - "clip_target": True, - "data_loader_kwargs": { - "target_normalizer": GroupNormalizer( - groups=["agency", "sku"], transformation="logit" - ) - }, - }, - "LogNormalDistributionLoss": { - "clip_target": True, - "data_loader_kwargs": { - "target_normalizer": GroupNormalizer( - groups=["agency", "sku"], transformation="log1p" - ) - }, - }, - "NegativeBinomialDistributionLoss": { - "clip_target": False, - "data_loader_kwargs": { - "target_normalizer": GroupNormalizer(groups=["agency", "sku"], center=False) - }, - }, - "MultivariateNormalDistributionLoss": { - "data_loader_kwargs": { - "target_normalizer": GroupNormalizer( - groups=["agency", "sku"], transformation="log1p" - ) - }, - }, - "MQF2DistributionLoss": { - "clip_target": True, - "data_loader_kwargs": { - "target_normalizer": GroupNormalizer( - groups=["agency", "sku"], center=False, transformation="log1p" - ) - }, - "trainer_kwargs": dict(accelerator="cpu"), - }, -} +METRIC_PKGS = all_objects(object_types="metric", return_names=False) def get_compatible_losses(pred_types, y_types): - """Get compatible losses based on prediction types and target types. - - Parameters - ---------- - pred_types : list of str - Prediction types, e.g., ["point", "distr"] - y_types : list of str - Target types, e.g., ["numeric", "category"] - - Returns - ------- - list - List of compatible loss instances + """ + Get compatible losses based on prediction types and target types. + Returns a list of (pkg, loss_instance) tuples. """ compatible_losses = [] - - for pred_type in pred_types: - for y_type in y_types: - key = (pred_type, y_type) - if key in LOSSES_BY_PRED_AND_Y_TYPE: - compatible_losses.extend(LOSSES_BY_PRED_AND_Y_TYPE[key]) - + for pkg in METRIC_PKGS: + pkg_pred_types = pkg.get_class_tag("info:pred_type", []) + pkg_y_types = pkg.get_class_tag("info:y_type", []) + if any(pt in pred_types for pt in pkg_pred_types) and any( + yt in y_types for yt in pkg_y_types + ): + compatible_losses.append((pkg, pkg.get_cls()())) return compatible_losses diff --git a/pytorch_forecasting/tests/test_all_estimators.py b/pytorch_forecasting/tests/test_all_estimators.py index d8eb7d81e..a6d0e811e 100644 --- a/pytorch_forecasting/tests/test_all_estimators.py +++ b/pytorch_forecasting/tests/test_all_estimators.py @@ -12,8 +12,6 @@ from pytorch_forecasting.tests._base._fixture_generator import BaseFixtureGenerator from pytorch_forecasting.tests._config import EXCLUDE_ESTIMATORS, EXCLUDED_TESTS from pytorch_forecasting.tests._loss_mapping import ( - LOSS_SPECIFIC_PARAMS, - LOSSES_BY_PRED_AND_Y_TYPE, get_compatible_losses, ) @@ -196,16 +194,16 @@ def _generate_final_param_list(self, compatible_losses, base_params_list): """ all_train_kwargs = [] train_kwargs_names = [] - for loss_item in compatible_losses: - if inspect.isclass(loss_item): - loss_name = loss_item.__name__ - loss = loss_item - else: - loss_name = loss_item.__class__.__name__ - loss = loss_item - loss_params = deepcopy(LOSS_SPECIFIC_PARAMS.get(loss_name, {})) - loss_params["loss"] = loss - + for pkg_cls, loss in compatible_losses: + loss_name = loss.__class__.__name__ + pkg_instance = pkg_cls() + clip_target = getattr(pkg_instance, "clip_target", False) + data_loader_kwargs = getattr(pkg_instance, "data_loader_kwargs", {}) + loss_params = { + "clip_target": clip_target, + "data_loader_kwargs": data_loader_kwargs, + "loss": loss, + } for i, base_params in enumerate(base_params_list): final_params = _nested_update(base_params, loss_params) all_train_kwargs.append(final_params)