
ShiftNet loss #8

Open · veseln opened this issue Feb 9, 2021 · 1 comment
veseln commented Feb 9, 2021

Hello,
First of all, thank you for publishing the code for this model; it is much appreciated.
I have a question about the ShiftNet loss. According to the paper (page 6, equation 5), the ShiftNet loss should be the L2 norm of the shifts. However, looking at the line in train.py where it is calculated,
https://github.com/ElementAI/HighRes-net/blob/aa022561c55ee3875b7d2e7837d37e620bf19315/src/train.py#L187
the ShiftNet loss is computed as `torch.mean(shifts)**2`, which is not the same thing. One potential issue with the loss as defined is that shifts can be positive or negative, so they can cancel out across the batch.
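
For example, here is a quick sketch with made-up shift values (not from the actual model) showing how the two expressions disagree and how opposite-signed shifts cancel in the mean:

```python
import torch

# Toy batch of two (dx, dy) shifts; the values are made up for illustration.
shifts = torch.tensor([[ 1.5, -0.5],
                       [-1.5,  0.5]])

mean_then_square = torch.mean(shifts) ** 2  # what train.py computes
l2_norm = torch.norm(shifts)                # the L2 norm from equation 5

print(mean_then_square.item())  # 0.0 -- opposite-signed shifts cancel in the mean
print(l2_norm.item())           # 2.2360... = sqrt(1.5**2 + 0.5**2 + 1.5**2 + 0.5**2)
```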

Is this a bug or was this intended? Were the results presented in the paper obtained with the loss as defined in the code or as defined in the paper?

Thank you!


alkalait commented Feb 9, 2021

Hello @veseln - thanks for raising this issue. I can confirm that this is indeed a bug.

You are entirely correct about the discrepancy between equation 5
(which defines the ShiftNet loss as the L2 norm of the predicted shifts)
and its implementation
https://github.com/ElementAI/HighRes-net/blob/aa022561c55ee3875b7d2e7837d37e620bf19315/src/train.py#L187

It seems to be a typo, and more than likely the results in the paper were obtained with the line as written in the code.

A few notes:

  1. I can confirm that the intention of Line 187 was to regularize the norm of the shifts (output by ShiftNet). As such, Line 187 is not strictly speaking the ShiftNet loss (or anything's loss for that matter). The ShiftNet parameters are learned by virtue of the quality of the reconstruction (the cPSNR-based loss). The regularizer was meant to keep the predicted shifts from deviating too much.

  2. The bad news is that, as you rightly suggested, the mean of all shifts will tend to zero, and so its square will converge to zero quadratically faster. All Line 187 penalizes is the squared batch mean of the shifts, so it's fair to say it is totally moot.

  3. The good bit of news is that, in retrospect, although we have not been regularizing the ShiftNet output all along, the end-to-end pipeline can work without this regularization. The question remains, however, of how the regularization affects overall performance. I would be grateful if you could run an ablation with the corrected line (see the sketch below) and report your results here.
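
For concreteness, a corrected regularizer could look something like the sketch below. This is only a sketch: `shift_regularizer`, `lambda_shift`, and the assumed `(batch_size, 2)` shape of `shifts` are illustrative, not names from the repo, and a squared-norm variant is noted in case that is the intended reading of equation 5.

```python
import torch

def shift_regularizer(shifts: torch.Tensor, lambda_shift: float = 1.0) -> torch.Tensor:
    """Penalize the L2 norm of the predicted shifts, per equation 5.

    Assumes `shifts` has shape (batch_size, 2). `lambda_shift` is a
    hypothetical weight, not a variable from train.py.
    """
    return lambda_shift * torch.mean(torch.norm(shifts, dim=1))
    # If equation 5 is read as the *squared* L2 norm instead:
    # return lambda_shift * torch.mean(torch.sum(shifts ** 2, dim=1))
```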

In summary
ShiftNet is still trained to accommodate the shifts between the SR output and the HR ground truth, but the predicted shifts are effectively not regularized at all.

Hope this helps.
