
[Feature] Torch Module Parameters Gradient Estimator #31

Merged
11 commits merged into main from torch_backward_estimator on Sep 18, 2024

Conversation

albertbou92
Contributor

Adds a TorchParamBackwardEstimator class, which includes a backward method. This method can be used to monkey-patch the TorchOp class during gradient computation. The patch allows the gradients to be computed with respect to the internal parameters of the TorchOp's underlying torch.nn.Module.
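The estimator's implementation isn't shown in this thread; the following is a minimal sketch of how such a class might work, assuming only what the PR description and the test snippet reveal (a constructor taking a `torch.nn.Module` and a `backward` method). The signature and return type here are illustrative, not the PR's actual code.

```python
import torch


class TorchParamBackwardEstimator:
    """Sketch (assumed implementation): computes gradients of a module's
    output with respect to its internal parameters, as a drop-in backward
    for a TorchOp-style wrapper."""

    def __init__(self, module: torch.nn.Module):
        self.module = module
        self.params = dict(module.named_parameters())

    def backward(self, inputs: torch.Tensor, grad_output: torch.Tensor) -> dict[str, torch.Tensor]:
        # Re-run the forward pass so autograd tracks the parameters,
        # then backpropagate the incoming gradient through them.
        output = self.module(inputs)
        grads = torch.autograd.grad(
            output, list(self.params.values()), grad_outputs=grad_output
        )
        return dict(zip(self.params.keys(), grads))


module = torch.nn.Linear(4, 1)
estimator = TorchParamBackwardEstimator(module)
x = torch.randn(4)
param_grads = estimator.backward(x, grad_output=torch.ones(1))
print(sorted(param_grads))          # ['bias', 'weight']
print(param_grads["weight"].shape)  # torch.Size([1, 4])
```

Each parameter gradient has the same shape as the parameter itself, so the result can be applied directly as an update to the module's state dict.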

Comment on lines 386 to 389:

    async def test_torch_param_backward_estimator():
        torch_module = torch.nn.Linear(4, 1)
        torch_op = TorchOp(torch_module)
        estimator = TorchParamBackwardEstimator(torch_module)
Collaborator


This looks pretty good! Following our conversation this morning, do you mind testing for output dim >1? It's worth checking that the shapes are correct in that case too, since we apply a few squeezes in the code.
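A sketch of what such a shape check for output dim > 1 might look like, written against plain `torch.autograd.grad` rather than the PR's `TorchOp`/`TorchParamBackwardEstimator` classes (whose internals are not shown in this thread):

```python
import torch

# Hypothetical check: a module whose output dimension is greater than 1.
module = torch.nn.Linear(4, 3)
x = torch.randn(4)
output = module(x)  # shape: (3,)

# Gradients of the output w.r.t. the module's parameters.
weight_grad, bias_grad = torch.autograd.grad(
    output,
    list(module.parameters()),
    grad_outputs=torch.ones_like(output),
)

# Each gradient must match its parameter's shape, regardless of output dim.
print(weight_grad.shape)  # torch.Size([3, 4])
print(bias_grad.shape)    # torch.Size([3])
```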

@albertbou92
Contributor Author

I extended the tests.

I noticed that at the end of the graph we set output_grads to [0]
https://github.com/Future-House/ldp/blob/main/ldp/graph/ops.py#L102

And then we try to expand the dimensions in TorchOp
https://github.com/Future-House/ldp/blob/main/ldp/graph/torch_ops.py#L80

But grad_outputs seems to require exactly the same shape as the output.
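The constraint can be demonstrated directly: `torch.autograd.grad` rejects a `grad_outputs` tensor whose shape differs from the output's, so a placeholder like `[0]` must be expanded to the output's exact shape before the call.

```python
import torch

module = torch.nn.Linear(4, 3)
x = torch.randn(4)
output = module(x)  # shape: (3,)

# A grad_outputs of shape (1,) does not match the output's shape (3,)
# and is rejected rather than broadcast:
try:
    torch.autograd.grad(
        output, list(module.parameters()), grad_outputs=torch.tensor([0.0])
    )
except RuntimeError as err:
    print("shape mismatch:", err)

# Expanding to the exact output shape works:
grads = torch.autograd.grad(
    output, list(module.parameters()), grad_outputs=torch.zeros(3)
)
print(grads[0].shape)  # torch.Size([3, 4])
```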

albertbou92 and others added 3 commits September 17, 2024 12:56
Co-authored-by: James Braza <jamesbraza@gmail.com>
albertbou92 and others added 2 commits September 17, 2024 18:05
Co-authored-by: James Braza <jamesbraza@gmail.com>
Collaborator

@jamesbraza jamesbraza left a comment


One question and a comment, then it looks good to me.

Collaborator

@jamesbraza jamesbraza left a comment


Nice work here!

@albertbou92 albertbou92 merged commit 911e685 into main Sep 18, 2024
6 checks passed
@albertbou92 albertbou92 deleted the torch_backward_estimator branch September 18, 2024 18:43