Skip to content

Conversation

@AnMakc
Copy link

@AnMakc AnMakc commented Nov 7, 2025

This PR fixes patch normalization issue in TimesFM model:

  • Numerical instability for low-variance data
  • Incorrect clamping of calculated sigma to one instead of configured eps

Issue only affects data with low relative variance and results in incorrect prediction.
Variance is now calculated after mean removal to avoid catastrophic cancellation.

@kashif Please, take a look as you contributed this model.

Bug demo

image
Code to reproduce
import random

import numpy as np
import matplotlib.pyplot as plt
import torch

from transformers import TimesFmModelForPrediction


seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)


model = TimesFmModelForPrediction.from_pretrained(
    "google/timesfm-2.0-500m-pytorch",
    attn_implementation="sdpa",
    device_map="auto",
)

model.eval()


X_inp = np.arange(256)
X_pred = np.arange(256, 256 + 128)

fig, axs = plt.subplots(2, 2, figsize=(12, 12))

for ax, sigma in zip(axs.flatten(), [1e-0, 1e-1, 1e-2, 1e-3]):
    rng = np.random.default_rng(seed=seed)
    Y_inp = 1e3 + rng.normal(loc=0, scale=sigma, size=(256,)).cumsum()

    inputs = [
        torch.from_numpy(Y_inp).to(model.device, dtype=model.dtype)
    ]

    with torch.no_grad():
        pred = model(past_values=inputs)
        pred = pred.mean_predictions.cpu().numpy().squeeze()

    ax.plot(X_inp, Y_inp, label='Inputs')
    ax.plot(X_pred, pred, label='Mean Prediction')
    ax.set_title(f'Sigma={sigma}')
    ax.legend()

@github-actions
Copy link
Contributor

github-actions bot commented Nov 7, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: timesfm

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant