Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kolmogorov Arnold Block for NBeats #1751

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 113 additions & 40 deletions pytorch_forecasting/models/nbeats/_nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ def __init__(
expansion_coefficient_lengths: Optional[List[int]] = None,
prediction_length: int = 1,
context_length: int = 1,
use_kan: bool = False,
benHeid marked this conversation as resolved.
Show resolved Hide resolved
num_grids: int = 5,
k: int = 3,
noise_scale: float = 0.5,
scale_base_mu: float = 0.0,
scale_base_sigma: float = 1.0,
scale_sp: float = 1.0,
base_fun: callable = torch.nn.SiLU(),
benHeid marked this conversation as resolved.
Show resolved Hide resolved
grid_eps: float = 0.02,
grid_range: List[int] = [-1, 1],
benHeid marked this conversation as resolved.
Show resolved Hide resolved
sp_trainable: bool = True,
sb_trainable: bool = True,
sparse_init: bool = False,
dropout: float = 0.1,
learning_rate: float = 1e-2,
log_interval: int = -1,
Expand All @@ -47,48 +60,86 @@ def __init__(

Based on the article
`N-BEATS: Neural basis expansion analysis for interpretable time series
forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if used as ensemble) outperformed all
other methods
including ensembles of traditional statical methods in the M4 competition. The M4 competition is arguably
the most
important benchmark for univariate time series forecasting.
forecasting <http://arxiv.org/abs/1905.10437>`_. The network has (if
benHeid marked this conversation as resolved.
Show resolved Hide resolved
used as ensemble) outperformed all other methods including ensembles of
traditional statical methods in the M4 competition. The M4 competition is
arguably the most important benchmark for univariate time series forecasting.

The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently shown to consistently outperform
N-BEATS.
The :py:class:`~pytorch_forecasting.models.nhits.NHiTS` network has recently
shown to consistently outperform N-BEATS.

Args:
stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings
of length 1 or ‘num_stacks’. Default and recommended value
for generic mode: [“generic”] Recommended value for interpretable mode: [“trend”,”seasonality”]
num_blocks: The number of blocks per stack. A list of ints of length 1 or ‘num_stacks’.
Default and recommended value for generic mode: [1] Recommended value for interpretable mode: [3]
num_block_layers: Number of fully connected layers with ReLu activation per block. A list of ints of length
1 or ‘num_stacks’.
Default and recommended value for generic mode: [4] Recommended value for interpretable mode: [4]
width: Widths of the fully connected layers with ReLu activation in the blocks.
A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [512]
Recommended value for interpretable mode: [256, 2048]
stack_types: One of the following values: “generic”, “seasonality" or
“trend". A list of strings of length 1 or ‘num_stacks’. Default and
recommended value for generic mode: [“generic”] Recommended value for
interpretable mode: [“trend”,”seasonality”].
num_blocks: The number of blocks per stack. A list of ints of length 1 or
‘num_stacks’. Default and recommended value for generic mode: [1]
Recommended value for interpretable mode: [3]
num_block_layers: Number of fully connected layers with ReLu activation per
block.
A list of ints of length 1 or ‘num_stacks’. Default and recommended
value for generic mode: [4] Recommended value for interpretable mode:
[4].
width: Widths of the fully connected layers with ReLu activation in the
blocks. A list of ints of length 1 or ‘num_stacks’. Default and
recommended value for generic mode: [512]. Recommended value for
interpretable mode: [256, 2048]
sharing: Whether the weights are shared with the other blocks per stack.
A list of ints of length 1 or ‘num_stacks’. Default and recommended value for generic mode: [False]
Recommended value for interpretable mode: [True]
expansion_coefficient_length: If the type is “G” (generic), then the length of the expansion
coefficient.
If type is “T” (trend), then it corresponds to the degree of the polynomial. If the type is “S”
(seasonal) then this is the minimum period allowed, e.g. 2 for changes every timestep.
A list of ints of length 1 or ‘num_stacks’. Default value for generic mode: [32] Recommended value for
A list of ints of length 1 or ‘num_stacks’. Default and recommended
value for generic mode: [False]. Recommended value for interpretable
mode: [True].
expansion_coefficient_length: If the type is “G” (generic), then the length
of the expansion coefficient.
If type is “T” (trend), then it corresponds to the degree of the
polynomial.
If the type is “S” (seasonal) then this is the minimum period allowed,
e.g. 2 for changes every timestep. A list of ints of length 1 or
‘num_stacks’. Default value for generic mode: [32] Recommended value for
interpretable mode: [3]
prediction_length: Length of the prediction. Also known as 'horizon'.
context_length: Number of time units that condition the predictions. Also known as 'lookback period'.
context_length: Number of time units that condition the predictions.
Also known as 'lookback period'.
Should be between 1-10 times the prediction length.
backcast_loss_ratio: weight of backcast in comparison to forecast when calculating the loss.
A weight of 1.0 means that forecast and backcast loss is weighted the same (regardless of backcast and
forecast lengths). Defaults to 0.0, i.e. no weight.
num_grids : Parameter for KAN layer. the number of grid intervals = G.
benHeid marked this conversation as resolved.
Show resolved Hide resolved
benHeid marked this conversation as resolved.
Show resolved Hide resolved
Default: 5.
k : Parameter for KAN layer. the order of piecewise polynomial. Default: 3.
noise_scale : Parameter for KAN layer. the scale of noise injected at
initialization. Default: 0.1.
scale_base_mu : Parameter for KAN layer. the scale of the residual
function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
Deafult: 0.0
scale_base_sigma : Parameter for KAN layer. the scale of the residual
function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
Deafult: 1.0
scale_sp : Parameter for KAN layer. the scale of the base function
spline(x). Deafult: 1.0
base_fun : Parameter for KAN layer. residual function b(x).
Default: torch.nn.SiLU()
grid_eps : Parameter for KAN layer. When grid_eps = 1, the grid is uniform;
when grid_eps = 0, the grid is partitioned using percentiles of samples.
0 < grid_eps < 1 interpolates between the two extremes. Deafult: 0.02
grid_range : Parameter for KAN layer. list/np.array of shape (2,). setting
the range of grids.
Default: [-1,1].
sp_trainable : Parameter for KAN layer. If true, scale_sp is trainable.
Default: True.
sb_trainable : Parameter for KAN layer. If true, scale_base is trainable.
Default: True.
sparse_init : Parameter for KAN layer. if sparse_init = True, sparse
initialization is applied. Default: False.
backcast_loss_ratio: weight of backcast in comparison to forecast when
calculating the loss. A weight of 1.0 means that forecast and
backcast loss is weighted the same (regardless of backcast and forecast
lengths). Defaults to 0.0, i.e. no weight.
loss: loss to optimize. Defaults to MASE().
log_gradient_flow: if to log gradient flow, this takes time and should be only done to diagnose training
failures
reduce_on_plateau_patience (int): patience after which learning rate is reduced by a factor of 10
logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that are logged during training.
Defaults to nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
log_gradient_flow: if to log gradient flow, this takes time and should be
only done to diagnose training failures.
reduce_on_plateau_patience (int): patience after which learning rate is
reduced by a factor of 10
logging_metrics (nn.ModuleList[MultiHorizonMetric]): list of metrics that
are logged during training. Defaults to
nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
**kwargs: additional arguments to :py:class:`~BaseModel`.
""" # noqa: E501
if expansion_coefficient_lengths is None:
Expand All @@ -107,7 +158,24 @@ def __init__(
logging_metrics = nn.ModuleList([SMAPE(), MAE(), RMSE(), MAPE(), MASE()])
if loss is None:
loss = MASE()
self.save_hyperparameters()
# Bundle KAN parameters into a dictionary
self.kan_params = {
"use_kan": use_kan,
"num_grids": num_grids,
"k": k,
"noise_scale": noise_scale,
"scale_base_mu": scale_base_mu,
"scale_base_sigma": scale_base_sigma,
"scale_sp": scale_sp,
"base_fun": base_fun,
"grid_eps": grid_eps,
"grid_range": grid_range,
"sp_trainable": sp_trainable,
"sb_trainable": sb_trainable,
"sparse_init": sparse_init,
}

self.save_hyperparameters(ignore=["loss", "logging_metrics"])
super().__init__(loss=loss, logging_metrics=logging_metrics, **kwargs)

# setup stacks
Expand All @@ -122,6 +190,7 @@ def __init__(
backcast_length=context_length,
forecast_length=prediction_length,
dropout=self.hparams.dropout,
kan_params=self.kan_params,
)
elif stack_type == "seasonality":
net_block = NBEATSSeasonalBlock(
Expand All @@ -131,6 +200,7 @@ def __init__(
forecast_length=prediction_length,
min_period=self.hparams.expansion_coefficient_lengths[stack_id],
dropout=self.hparams.dropout,
kan_params=self.kan_params,
)
elif stack_type == "trend":
net_block = NBEATSTrendBlock(
Expand All @@ -140,6 +210,7 @@ def __init__(
backcast_length=context_length,
forecast_length=prediction_length,
dropout=self.hparams.dropout,
kan_params=self.kan_params,
)
else:
raise ValueError(f"Unknown stack type {stack_type}")
Expand Down Expand Up @@ -223,7 +294,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
@classmethod
def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
"""
Convenience function to create network from :py:class`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.
Convenience function to create network from :py:class
`~pytorch_forecasting.data.timeseries.TimeSeriesDataSet`.

Args:
dataset (TimeSeriesDataSet): dataset where sole predictor is the target.
Expand Down Expand Up @@ -359,10 +431,11 @@ def plot_interpretation(
x (Dict[str, torch.Tensor]): network input
output (Dict[str, torch.Tensor]): network output
idx (int): index of sample for which to plot the interpretation.
ax (List[matplotlib axes], optional): list of two matplotlib axes onto which to plot the interpretation.
Defaults to None.
plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot seasonality and
generic forecast on secondary axis in second panel. Defaults to False.
ax (List[matplotlib axes], optional): list of two matplotlib axes onto which
to plot the interpretation. Defaults to None.
plot_seasonality_and_generic_on_secondary_axis (bool, optional): if to plot
seasonality and generic forecast on secondary axis in second panel.
Defaults to False.

Returns:
plt.Figure: matplotlib figure
Expand Down
Loading
Loading