The Backwards implementation #24

Open
HuuYuLong opened this issue May 11, 2024 · 2 comments

Comments

@HuuYuLong

Hello,

Thanks for the excellent work!

I'm still confused about the backward implementation.
The gradient of iFFT(FFT(x)) should be FFT(iFFT(dout)). But what is the principle behind the gradients dx and dy of iFFT(FFT(x) * FFT(y))?

Thanks!

@DanFu09
Contributor

DanFu09 commented May 11, 2024

Hello! Hopefully it helps to work through the backprop calculation by hand.

First, I'll note that iFFT(FFT(x)) = x, so if out = iFFT(FFT(x)), then the gradient dx = dout = FFT(iFFT(dout)).
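A quick numerical sanity check of this identity case, assuming NumPy's `np.fft.fft`/`np.fft.ifft` stand in for FFT/iFFT (the thread does not name a library). It builds the Jacobian of x -> iFFT(FFT(x)) from basis vectors, confirms it is the identity, and so the backward pass just passes dout through.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8
x = rng.standard_normal(n)

# Forward: iFFT(FFT(x)) reproduces x up to floating-point round-off.
out = np.fft.ifft(np.fft.fft(x)).real
print(np.allclose(out, x))  # True

# Jacobian of x -> iFFT(FFT(x)): apply the map to each basis vector (columns of eye).
jac = np.fft.ifft(np.fft.fft(np.eye(n), axis=0), axis=0).real
print(np.allclose(jac, np.eye(n)))  # True

# Backward: the vector-Jacobian product jac.T @ dout is just dout,
# which matches FFT(iFFT(dout)).
dout = rng.standard_normal(n)
dx = jac.T @ dout
dx_fft = np.fft.fft(np.fft.ifft(dout)).real
print(np.allclose(dx, dout), np.allclose(dx_fft, dout))  # True True
```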

Now let's look at iFFT(FFT(x) * FFT(y)).

Let's split this into a couple steps:

xf = FFT(x)
yf = FFT(y)
zf = xf * yf
out = iFFT(zf)

Hopefully you can see that this out is the same as iFFT(FFT(x) * FFT(y)).
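The four forward steps can be sketched in NumPy, again assuming `np.fft` for FFT/iFFT and real-valued inputs. `fft_conv` and `circular_conv_direct` are hypothetical helper names. The check confirms that the pipeline computes a circular convolution of x and y, which is the function the backward derivation differentiates.

```python
import numpy as np

def fft_conv(x, y):
    xf = np.fft.fft(x)           # xf = FFT(x)
    yf = np.fft.fft(y)           # yf = FFT(y)
    zf = xf * yf                 # zf = xf * yf (elementwise)
    # out = iFFT(zf); for real x and y the imaginary part is only round-off.
    return np.fft.ifft(zf).real

def circular_conv_direct(x, y):
    # out[m] = sum_k x[k] * y[(m - k) mod n], computed without any FFT.
    n = len(x)
    return np.array([sum(x[k] * y[(m - k) % n] for k in range(n)) for m in range(n)])

rng = np.random.default_rng(0)
x = rng.standard_normal(16)
y = rng.standard_normal(16)
print(np.allclose(fft_conv(x, y), circular_conv_direct(x, y)))  # True
```

For instance, `fft_conv([1, 2, 0, 0], [1, 1, 0, 0])` gives approximately `[1, 3, 2, 0]`.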

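And now let's work through each step of this by hand, going backwards. Note that I'm being a bit fast and loose with the constants here (in particular, individual steps may be missing a factor of n or 1/n, though with the default normalization, unscaled FFT and iFFT scaled by 1/n, these cancel in the final result). Also, the backward pass through the elementwise product uses the complex conjugate of the other factor, written conj(·) below, so dx ends up as a circular cross-correlation of dout with y rather than a convolution.

dzf = FFT(dout)        # differentiate iFFT
dxf = dzf * conj(yf)   # product rule (conjugate of the other factor)
dyf = dzf * conj(xf)   # product rule (conjugate of the other factor)
dx = iFFT(dxf)         # differentiate FFT
dy = iFFT(dyf)         # differentiate FFT

So putting it together, we have:

dx = iFFT(FFT(dout) * conj(FFT(y)))
dy = iFFT(FFT(dout) * conj(FFT(x)))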
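One way to check backward formulas like these is a finite-difference test. This sketch assumes real-valued x and y, NumPy's `np.fft` for FFT/iFFT, and hypothetical helper names (`fft_conv`, `fft_conv_backward`, `numerical_grad`). Under those assumptions the finite differences match dx = iFFT(FFT(dout) * conj(FFT(y))) and dy = iFFT(FFT(dout) * conj(FFT(x))), that is, with the other factor's spectrum conjugated. The last lines compute the unconjugated variant for comparison.

```python
import numpy as np

def fft_conv(x, y):
    # Forward pass: out = iFFT(FFT(x) * FFT(y)); real inputs assumed.
    return np.fft.ifft(np.fft.fft(x) * np.fft.fft(y)).real

def fft_conv_backward(x, y, dout):
    # Backward pass for real x, y: conjugate the other factor's spectrum.
    dout_f = np.fft.fft(dout)
    dx = np.fft.ifft(dout_f * np.conj(np.fft.fft(y))).real
    dy = np.fft.ifft(dout_f * np.conj(np.fft.fft(x))).real
    return dx, dy

def numerical_grad(f, v, eps=1e-6):
    # Central finite differences of the scalar function f at v.
    g = np.zeros_like(v)
    for i in range(len(v)):
        e = np.zeros_like(v)
        e[i] = eps
        g[i] = (f(v + e) - f(v - e)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n = 8
x, y, dout = rng.standard_normal((3, n))

# Scalar losses whose gradient with respect to out is exactly dout.
def loss_x(v):
    return np.dot(fft_conv(v, y), dout)

def loss_y(v):
    return np.dot(fft_conv(x, v), dout)

dx, dy = fft_conv_backward(x, y, dout)
print(np.allclose(dx, numerical_grad(loss_x, x), atol=1e-6))  # True
print(np.allclose(dy, numerical_grad(loss_y, y), atol=1e-6))  # True

# Without the conjugate the result is a circular convolution, not a
# cross-correlation, and it generally fails the finite-difference check.
dx_noconj = np.fft.ifft(np.fft.fft(dout) * np.fft.fft(y)).real
print(np.allclose(dx_noconj, numerical_grad(loss_x, x), atol=1e-6))  # False
```

Since the loss is linear in each input, central differences are exact up to round-off, so a tight tolerance is safe here.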

Hope this helps!

@HuuYuLong
Author

Thanks again, it's really helpful!!!

HuuYuLong reopened this May 13, 2024