Hello,
First of all, thank you for publishing the code for this model, it is much appreciated.
I have a question about the ShiftNet loss. According to the paper (page 6, equation 5), the ShiftNet loss should be the L2 norm of the shifts. However, looking at the line in train.py where it is computed, https://github.com/ElementAI/HighRes-net/blob/aa022561c55ee3875b7d2e7837d37e620bf19315/src/train.py#L187
the ShiftNet loss is `torch.mean(shifts)**2`, which is not the same thing. One potential issue with the loss as defined is that shifts can be both positive and negative, so they can cancel out across the batch.
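To make the cancellation concrete, here is a small sketch (with made-up shift values) contrasting the line as coded with a batch-averaged squared L2 norm, which is closer to equation 5 in the paper:

```python
import torch

# Hypothetical batch of 2D shifts predicted by ShiftNet.
# Positive and negative shifts are both plausible across a batch.
shifts = torch.tensor([[ 0.5, -1.0],
                       [-0.5,  1.0]])

# Line 187 as written: square of the mean over ALL entries.
# Opposite-signed shifts cancel before the square is taken.
loss_as_coded = torch.mean(shifts) ** 2

# Squared L2 norm per sample, averaged over the batch.
loss_l2 = torch.mean(torch.sum(shifts ** 2, dim=1))

print(loss_as_coded.item())  # 0.0  -- the shifts cancel out entirely
print(loss_l2.item())        # 1.25 -- nonzero, penalizes large shifts
```

Even though every individual shift here is large, the coded version evaluates to zero, so it exerts no pressure on the magnitude of the predicted shifts.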
Is this a bug or was this intended? Were the results presented in the paper obtained with the loss as is defined in the code or as is defined in the paper?
Thank you!
It seems to be a typo, and more than likely the results were obtained with the code as written.
A few notes:
I can confirm that the intention of Line 187 was to regularize the norm of the shifts (output by ShiftNet). As such, Line 187 is not strictly speaking the ShiftNet loss (or anything's loss for that matter). The ShiftNet parameters are learned by virtue of the quality of the reconstruction (the cPSNR-based loss). The regularizer was meant to keep the predicted shifts from deviating too much.
The bad news is that, as you rightly suggested, the mean of all shifts will tend to zero, and therefore its square will converge quadratically faster to zero. All this does is regularize the squared mean of the shifts, so it's fair to say that Line 187 is effectively moot.
The good bit of news is that, in retrospect, although we have not been regularizing the ShiftNet output all along, the end-to-end pipeline can work without this regularization. The question remains, however, of how the regularization affects overall performance. I would be grateful if you performed an ablation with the corrected line and reported your results here.
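For the ablation, one possible drop-in correction for Line 187 (a sketch, not the authors' confirmed fix) is to penalize the mean squared shift rather than the squared mean, so that opposite-signed shifts no longer cancel:

```python
import torch

shifts = torch.tensor([[0.5, -1.0], [-0.5, 1.0]])  # hypothetical batch

# Original Line 187: moot, since the batch mean tends to zero.
reg_old = torch.mean(shifts) ** 2

# Possible corrected regularizer: mean squared shift magnitude.
# Equivalent to the batch-averaged squared L2 norm up to a constant
# factor (the number of shift components).
reg_new = torch.mean(shifts ** 2)

print(reg_old.item())  # 0.0
print(reg_new.item())  # 0.625
```

Any regularization weight multiplying this term would stay as in the existing training script.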
In summary
ShiftNet still learns to compensate for the shifts between the SR and HR ground truth, but the predicted shifts are effectively not regularized at all.