diff --git a/src/chronos/chronos_bolt.py b/src/chronos/chronos_bolt.py index f099e04..8ad3172 100644 --- a/src/chronos/chronos_bolt.py +++ b/src/chronos/chronos_bolt.py @@ -363,7 +363,7 @@ def forward( ) * target_mask.float() ) - loss = loss.mean(dim=-2) # Mean over prediction horizon + loss = loss.mean(dim=-1) # Mean over prediction horizon loss = loss.sum(dim=-1) # Sum over quantile levels loss = loss.mean() # Mean over batch