[Feature] Torch Module Parameters Gradient Estimator #31

Merged Sep 18, 2024 · 11 commits
66 changes: 66 additions & 0 deletions ldp/graph/gradient_estimators.py
@@ -6,10 +6,12 @@
from functools import partial
from typing import Any

import torch
import tree

from ldp.graph.op_utils import CallID
from ldp.graph.ops import GradInType, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp

logger = logging.getLogger(__name__)

@@ -142,3 +144,67 @@ def assign_default_gradients(

    tree.assert_same_structure(input_grads, input_args_kwargs)
    return input_grads


class TorchParamBackwardEstimator:
    """Gradient estimator for the parameters of a TorchOp's wrapped module.

    In addition to returning gradients w.r.t. the op's tensor inputs, this estimator
    computes gradients w.r.t. the module's parameters and stores them in the op
    context under the key "grads_params".
    """

    def __init__(self, module: torch.nn.Module):
        self.module = module
        self.params = dict(module.named_parameters())

    def backward(
        self,
        ctx: OpCtx,
        input_args: list[ResultOrValue],
        input_kwargs: dict[str, ResultOrValue],
        grad_output: tree.Structure,
        call_id: CallID,
    ) -> GradInType:
        if ctx.op_name != "TorchOp":
            raise ValueError(
                f"Attempted to use TorchParamBackwardEstimator with non-TorchOp operation {ctx.op_name}"
            )

        tensor_args, tensor_kwargs = ctx.get(call_id, TorchOp.CTX_TENSOR_INPUT_KEY)
        n_pos_args = len(tensor_args)
        n_kwargs = len(tensor_kwargs)
        output = ctx.get(call_id, "output").value

        if not isinstance(grad_output, torch.Tensor):
            grad_output = torch.tensor(
                grad_output, dtype=output.dtype, device=output.device
            )

        while grad_output.ndim < output.ndim:
            # Assume we can broadcast, so expand dims
            # e.g. if output.shape = (2, 1, 1) and grad_output is a scalar,
            # then we want to expand to (1, 1, 1) and then broadcast
            grad_output = grad_output.unsqueeze(-1)

        # Gradients come back in the same order as the inputs passed to
        # torch.autograd.grad: positional args, then kwargs, then module parameters.
        gradients = torch.autograd.grad(
            output,
            [*tensor_args, *tensor_kwargs.values(), *self.params.values()],
            grad_outputs=grad_output,
        )

        grad_args = [
            grad.detach().cpu().float() if grad is not None else None  # type: ignore[redundant-expr]
            for grad in gradients[:n_pos_args]
        ]
        grad_kwargs = {
            k: grad.detach().cpu().float()
            if grad is not None  # type: ignore[redundant-expr]
            else None
            for k, grad in zip(
                tensor_kwargs.keys(),
                gradients[n_pos_args : n_pos_args + n_kwargs],
                strict=True,
            )
        }
        grads_params = {
            name: grad.detach().cpu().float()
            for name, grad in zip(
                self.params.keys(), gradients[n_pos_args + n_kwargs :], strict=True
            )
        }

        ctx.update(call_id=call_id, key="grads_params", value=grads_params)

        return grad_args, grad_kwargs
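
For reference, a minimal usage sketch of the new estimator, based on the API exercised in the test below (TorchOp, OpResult.compute_grads, OpCtx). The function name sgd_step_example and the manual SGD update at the end are illustrative assumptions, not part of this PR:

import torch

from ldp.graph.gradient_estimators import TorchParamBackwardEstimator
from ldp.graph.torch_ops import TorchOp


async def sgd_step_example(lr: float = 1e-2) -> None:
    module = torch.nn.Linear(4, 1)
    op = TorchOp(module)
    estimator = TorchParamBackwardEstimator(module)

    # Forward pass through the op
    result = await op(torch.randn(4, requires_grad=True))

    # Backward pass, routing TorchOp's backward through the estimator
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Parameter gradients are stored in the op context under "grads_params"
    call_id = next(iter(op.get_call_ids({result.call_id.run_id})))
    grads_params = op.ctx.get(call_id, "grads_params")

    # Illustrative manual SGD step using the stored parameter gradients
    with torch.no_grad():
        for name, param in module.named_parameters():
            param -= lr * grads_params[name].to(device=param.device, dtype=param.dtype)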
23 changes: 23 additions & 0 deletions tests/test_gradients.py
@@ -7,15 +7,18 @@

import numpy as np
import pytest
import torch
import tree

from ldp.graph.common_ops import ConfigOp, FxnOp
from ldp.graph.gradient_estimators import (
TorchParamBackwardEstimator,
assign_constant_grads,
assign_default_grads,
)
from ldp.graph.op_utils import CallID, compute_graph
from ldp.graph.ops import GradInType, Op, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp


class PoissonSamplerOp(Op):
@@ -377,3 +380,23 @@ async def test_serial_ops_diff_run_id():

    with pytest.raises(RuntimeError, match="args and kwargs must have the same run_id"):
        await op2(result1)


@pytest.mark.asyncio
async def test_torch_param_backward_estimator():
    torch_module = torch.nn.Linear(4, 1)
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)
Collaborator (review comment):
This looks pretty good! Following our conversation this morning, do you mind testing for output dim >1? It's worth checking that the shapes are correct in that case too, since we apply a few squeezes in the code.


    # Forward pass
    result = await torch_op(torch.randn(4, requires_grad=True))

    # Backward pass
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Check that the gradients are computed and have the correct shape
    call_ids = torch_op.get_call_ids({result.call_id.run_id})
    grad_params = torch_op.ctx.get(next(iter(call_ids)), "grads_params")
    for param_name, param in torch_module.named_parameters():
        assert param_name in grad_params
        assert param.shape == grad_params[param_name].shape
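
Following the review comment above about output dim > 1, a hedged sketch of what that additional test might look like (assuming compute_grads seeds and broadcasts the output gradient for a vector-valued output the same way it does in the scalar case); the test name is hypothetical and the shape assertions are the point:

@pytest.mark.asyncio
async def test_torch_param_backward_estimator_multi_output():
    # Hypothetical variant of the test above with output dim > 1
    torch_module = torch.nn.Linear(4, 3)
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)

    # Forward pass
    result = await torch_op(torch.randn(4, requires_grad=True))

    # Backward pass
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Parameter gradients should match the parameter shapes, e.g. (3, 4) for the
    # weight and (3,) for the bias
    call_ids = torch_op.get_call_ids({result.call_id.run_id})
    grad_params = torch_op.ctx.get(next(iter(call_ids)), "grads_params")
    for param_name, param in torch_module.named_parameters():
        assert param_name in grad_params
        assert param.shape == grad_params[param_name].shape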