[Feature] Torch Module Parameters Gradient Estimator #31
Conversation
tests/test_gradients.py (outdated)
```python
async def test_torch_param_backward_estimator():
    torch_module = torch.nn.Linear(4, 1)
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)
```
This looks pretty good! Following our conversation this morning, do you mind testing for output dim >1? It's worth checking that the shapes are correct in that case too, since we apply a few squeezes in the code.
I extended the tests. I noticed that at the end of the graph we set `output_grads` to `[0]`, and then we try to expand the dimensions in `TorchOp`, but …
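For reference, the shape property under discussion can be checked in plain torch, independent of `TorchOp`; this is only an illustrative sketch (the real test goes through the estimator):

```python
import torch

# Module with output dim > 1, matching the reviewer's request above.
torch_module = torch.nn.Linear(4, 3)
x = torch.randn(2, 4)
out = torch_module(x)  # shape (2, 3): no squeeze should collapse this

# Seed the backward pass with an upstream gradient of ones (the analogue
# of output_grads at the end of the graph).
out.backward(torch.ones_like(out))

# Parameter gradients must match the parameter shapes exactly.
for name, param in torch_module.named_parameters():
    assert param.grad.shape == param.shape, name
```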
Co-authored-by: James Braza <jamesbraza@gmail.com>
One question and a comment, then it looks good to me.
Nice work here!
Adds a `TorchParamBackwardEstimator` class, which includes a `backward` method. This method can be used to monkey-patch the `TorchOp` class during gradient computation. The patch allows the gradients to be computed with respect to the internal parameters of the `TorchOp`'s underlying `torch.nn.Module`.
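A small self-contained sketch of the patching pattern described above. `MiniOp` and `ParamEstimator` are hypothetical stand-ins for `TorchOp` and `TorchParamBackwardEstimator` (whose real signatures live in this PR's diff); only the pattern itself is the point, not the repo's actual API.

```python
import torch

class MiniOp:
    """Hypothetical stand-in for TorchOp: wraps a module, exposes a backward hook."""
    def __init__(self, module: torch.nn.Module):
        self.module = module

    def backward(self, output: torch.Tensor, output_grads: torch.Tensor):
        # Default behavior: gradients w.r.t. inputs only; parameters are ignored.
        ...

class ParamEstimator:
    """Hypothetical stand-in for TorchParamBackwardEstimator."""
    def __init__(self, module: torch.nn.Module):
        self.module = module

    def backward(self, output: torch.Tensor, output_grads: torch.Tensor):
        # Also populates .grad on the module's internal parameters.
        output.backward(output_grads)

module = torch.nn.Linear(4, 1)
op = MiniOp(module)
estimator = ParamEstimator(module)

# The patch: route the op's backward through the estimator, so the backward
# pass also computes gradients for the underlying module's parameters.
op.backward = estimator.backward

x = torch.randn(2, 4)
out = module(x)
op.backward(out, torch.ones_like(out))
assert all(p.grad is not None for p in module.parameters())
```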