
About the learning rate setting of p_conv and m_conv #7

Open
dontLoveBugs opened this issue Feb 21, 2019 · 7 comments

@dontLoveBugs

dontLoveBugs commented Feb 21, 2019

You set the gradients of p_conv and m_conv to 0.1 times those of the other layers, but I find the gradients are unchanged after backward.
I used the following code to test:

    import torch
    import torch.nn as nn

    from deform_conv_v2 import DeformConv2d  # the module from this repo (adjust the import path if needed)

    # _set_lr is the backward hook defined in DeformConv2d, copied here for reference
    def _set_lr(module, grad_input, grad_output):
        print('grad input:', grad_input)
        print('grad output:', grad_output)
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    x = torch.randn(4, 3, 5, 5)
    y_ = torch.randn(4, 1, 5, 5)
    loss = nn.L1Loss()

    d_conv = DeformConv2d(inc=3, outc=1, modulation=True)

    y = d_conv(x)  # call the module (not .forward()) so that registered hooks run
    l = loss(y, y_)
    l.backward()

    print('p conv grad:')
    print(d_conv.p_conv.weight.grad)
    print('m conv grad:')
    print(d_conv.m_conv.weight.grad)
    print('conv grad:')
    print(d_conv.conv.weight.grad)

The gradient of p_conv is the same as grad_input, but I think the gradient of p_conv should be 0.1 times grad_input. Am I wrong?
[screenshots: the printed gradients of p_conv and conv]
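
For reference, a minimal sketch that reproduces the observation, using a plain nn.Conv2d as a stand-in for DeformConv2d (the layer and shapes here are illustrative, not the repo's): a backward hook that only rebinds its local grad_input/grad_output names and returns nothing leaves the gradients untouched.

    import torch
    import torch.nn as nn

    def _set_lr(module, grad_input, grad_output):
        # rebinding the local names has no effect outside the hook;
        # register_backward_hook only replaces grad_input with a value the hook *returns*
        # (note: register_backward_hook is deprecated in newer PyTorch in favor of register_full_backward_hook)
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    conv = nn.Conv2d(3, 1, 3, padding=1)
    conv.register_backward_hook(_set_lr)

    x = torch.randn(4, 3, 5, 5)
    nn.L1Loss()(conv(x), torch.randn(4, 1, 5, 5)).backward()
    print(conv.weight.grad)  # same values as without the hook, i.e. not scaled by 0.1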

@4uiiurz1 (Owner)

You're right!
I'll fix it.

@BananaLv26

> You're right!
> I'll fix it.

Have you solved this problem now?

@jszgz

jszgz commented May 28, 2020

@dontLoveBugs Hello, can you review my issue? I think the bilinear kernel is wrong.

@zcong17huang

> You're right!
> I'll fix it.

A 'tuple' object cannot be modified in place. Your code just creates a generator.

@XinZhangRadar

I have searched online: grad_output cannot be modified, but if you want to modify grad_input you need to return the modified grad_input from the hook, like:

    def _set_lr(module, grad_input, grad_output):
        return (grad_input[i] * 0.1 for i in range(len(grad_input)))

You can try it. My question is: why change the p_conv gradients? Is it to avoid affecting the learning of the other feature extraction branch?

@steven22tom

@XinZhangNLPR that is because the backward hook expects a tuple, not a 'generator'.

> I have searched online: grad_output cannot be modified, but if you want to modify grad_input you need to return the modified grad_input from the hook, like:
>
>     def _set_lr(module, grad_input, grad_output):
>         return (grad_input[i] * 0.1 for i in range(len(grad_input)))
>
> You can try it. My question is: why change the p_conv gradients? Is it to avoid affecting the learning of the other feature extraction branch?

Your suggestion still returns a generator, not a tuple.
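
Putting those two points together, a sketch of a hook that actually rescales the gradients would need to return a real tuple (and skip None entries); whether the repository fixes it exactly this way is not confirmed here:

    def _set_lr(module, grad_input, grad_output):
        # must return a tuple: the returned value replaces grad_input during backward
        return tuple(g * 0.1 if g is not None else None for g in grad_input)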

@YXB-NKU

YXB-NKU commented Oct 3, 2023

> You're right! I'll fix it.

It seems this bug has not been fixed yet.
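
A workaround that avoids the hook machinery altogether is to express the intended 0.1x factor through optimizer parameter groups; a rough sketch, assuming d_conv is the DeformConv2d instance from the test code above and base_lr is illustrative:

    import torch

    base_lr = 1e-3
    optimizer = torch.optim.SGD([
        {'params': d_conv.conv.parameters()},                         # base learning rate
        {'params': d_conv.p_conv.parameters(), 'lr': base_lr * 0.1},  # offset branch, 10x smaller
        {'params': d_conv.m_conv.parameters(), 'lr': base_lr * 0.1},  # modulation branch, 10x smaller
    ], lr=base_lr)

For plain SGD this is equivalent to scaling the gradients by 0.1; for adaptive optimizers such as Adam the two are not exactly the same, since the gradient scale is partly normalized away.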
