[Feature] Torch Module Parameters Gradient Estimator #31

Merged · 11 commits · Sep 18, 2024
87 changes: 87 additions & 0 deletions ldp/graph/gradient_estimators.py
@@ -6,10 +6,12 @@
from functools import partial
from typing import Any

import torch
import tree

from ldp.graph.op_utils import CallID
from ldp.graph.ops import GradInType, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp

logger = logging.getLogger(__name__)

@@ -142,3 +144,88 @@ def assign_default_gradients(

    tree.assert_same_structure(input_grads, input_args_kwargs)
    return input_grads


class TorchParamBackwardEstimator:
"""
Gradient estimator for `TorchOp` internal parameters.

This estimator computes gradients with respect to the internal parameters of a
`torch.nn.Module` by calling the `backward` method of the estimator instead of the default
`backward` method of `TorchOp`. Computed gradients are stored in the context of the operation
under the key `"grad_params"`.

Examples:
>>> torch_module = torch.nn.Sequential(
... torch.nn.Linear(4, 4),
... torch.nn.Linear(4, 1),
... )
>>> torch_op = TorchOp(torch_module)
>>> estimator = TorchParamBackwardEstimator(torch_module)
>>> result = await torch_op(torch.randn(4, requires_grad=True))
>>> result.compute_grads(backward_fns={"TorchOp": estimator.backward})

Note:
This estimator is only compatible with `TorchOp` operations.
"""

    def __init__(self, module: torch.nn.Module):
        self.params = dict(module.named_parameters())

    def backward(
        self,
        ctx: OpCtx,
        input_args: list[ResultOrValue],
        input_kwargs: dict[str, ResultOrValue],
        grad_output: tree.Structure,
        call_id: CallID,
    ) -> GradInType:
        if ctx.op_name != TorchOp.__name__:
            raise RuntimeError(
                f"Attempted to use TorchParamBackwardEstimator with non-TorchOp operation {ctx.op_name}."
            )

        tensor_args, tensor_kwargs = ctx.get(call_id, TorchOp.CTX_TENSOR_INPUT_KEY)
        n_pos_args = len(tensor_args)
        n_pos_kwargs = n_pos_args + len(tensor_kwargs)  # end of the kwarg gradients in the flat gradient list below
        output = ctx.get(call_id, "output").value

        if not isinstance(grad_output, torch.Tensor):
            grad_output = torch.tensor(
                grad_output, dtype=output.dtype, device=output.device
            )

        while grad_output.ndim < output.ndim:
            # Assume we can broadcast, so expand dims
            # e.g. if output.shape = (2, 1, 1) and grad_output is a scalar
            # then we want to expand to (1, 1, 1) and then broadcast
            grad_output = grad_output.unsqueeze(-1)

        if output.shape != grad_output.shape:
            raise RuntimeError(
                f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
            )

        gradients = torch.autograd.grad(
            output,
            [*tensor_args, *tensor_kwargs.values(), *self.params.values()],
            grad_outputs=grad_output,
        )

        grad_args = [grad.detach().cpu().float() for grad in gradients[:n_pos_args]]
        grad_kwargs = {
            k: grad.detach().cpu().float()
            for k, grad in zip(
                tensor_kwargs.keys(), gradients[n_pos_args:n_pos_kwargs], strict=True
            )
        }
        grad_params = {
            name: grad.detach().cpu().float()
            for name, grad in zip(
                self.params.keys(), gradients[n_pos_kwargs:], strict=True
            )
        }

        ctx.update(call_id=call_id, key="grad_params", value=grad_params)

        return grad_args, grad_kwargs
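
For reference, a minimal end-to-end sketch of the new estimator, assembled from the docstring example above and the test added below. It only uses the `ldp` graph API already shown in this diff (`TorchOp`, `OpResult.compute_grads`, `Op.get_call_ids`, `OpCtx.get`); the `main` coroutine and `asyncio.run` wrapper are just illustrative scaffolding.

# Sketch only: mirrors the docstring example and tests/test_gradients.py below.
import asyncio

import torch

from ldp.graph.gradient_estimators import TorchParamBackwardEstimator
from ldp.graph.torch_ops import TorchOp


async def main() -> None:
    torch_module = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)

    # Forward pass through the op, then backward with the estimator registered for TorchOp
    result = await torch_op(torch.randn(4, requires_grad=True))
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Parameter gradients are stored in the op context under "grad_params"
    call_id = next(iter(torch_op.get_call_ids({result.call_id.run_id})))
    grad_params = torch_op.ctx.get(call_id, "grad_params")
    for name, grad in grad_params.items():
        print(name, tuple(grad.shape))


asyncio.run(main())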
12 changes: 6 additions & 6 deletions ldp/graph/torch_ops.py
@@ -83,6 +83,11 @@ def backward(
            # then we want to expand to (1, 1, 1) and then broadcast
            grad_output = grad_output.unsqueeze(-1)

+        if output.shape != grad_output.shape:
+            raise RuntimeError(
+                f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
+            )

        gradients = torch.autograd.grad(
            output,
            [*tensor_args, *tensor_kwargs.values()],
@@ -91,14 +96,9 @@
            retain_graph=True,
        )

-        grad_args = [
-            grad.detach().cpu().float() if grad is not None else None  # type: ignore[redundant-expr]
-            for grad in gradients[:n_pos_args]
-        ]
+        grad_args = [grad.detach().cpu().float() for grad in gradients[:n_pos_args]]
        grad_kwargs = {
            k: grad.detach().cpu().float()
-            if grad is not None  # type: ignore[redundant-expr]
-            else None
            for k, grad in zip(
                tensor_kwargs.keys(), gradients[n_pos_args:], strict=True
            )
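
To make the new guard concrete, here is a small standalone illustration (not library code) of how a scalar `grad_output` is expanded and then compared against the op output's shape:

# Illustration only: the seed gradient is unsqueezed until its rank matches the output,
# then the shapes must match exactly.
import torch

output = torch.zeros(1, 1)       # stand-in for the op's tensor output
grad_output = torch.tensor(1.0)  # scalar seed gradient

while grad_output.ndim < output.ndim:
    grad_output = grad_output.unsqueeze(-1)

assert output.shape == grad_output.shape  # (1, 1) == (1, 1): OK

# For an output such as torch.zeros(2, 3), the unsqueezed scalar ends up with shape (1, 1),
# so the new check raises RuntimeError rather than relying on implicit broadcasting.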
28 changes: 28 additions & 0 deletions tests/test_gradients.py
@@ -7,15 +7,18 @@

import numpy as np
import pytest
import torch
import tree

from ldp.graph.common_ops import ConfigOp, FxnOp
from ldp.graph.gradient_estimators import (
    TorchParamBackwardEstimator,
    assign_constant_grads,
    assign_default_grads,
)
from ldp.graph.op_utils import CallID, compute_graph
from ldp.graph.ops import GradInType, Op, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp


class PoissonSamplerOp(Op):
@@ -377,3 +380,28 @@ async def test_serial_ops_diff_run_id():

    with pytest.raises(RuntimeError, match="args and kwargs must have the same run_id"):
        await op2(result1)


@pytest.mark.parametrize("hidden_nodes", [1, 4])
@pytest.mark.asyncio
async def test_torch_param_backward_estimator(hidden_nodes: int):
    torch_module = torch.nn.Sequential(
        torch.nn.Linear(4, hidden_nodes),
        torch.nn.Linear(hidden_nodes, 1),
    )
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)

    # Forward pass
    result = await torch_op(torch.randn(4, requires_grad=True))

    # Backward pass
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Check that the gradients are computed and have the correct shape
    call_ids = torch_op.get_call_ids({result.call_id.run_id})
    grad_params = torch_op.ctx.get(next(iter(call_ids)), "grad_params")

    for named_param, grad_param in torch_module.named_parameters():
        assert named_param in grad_params
        assert grad_param.shape == grad_params[named_param].shape
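
Beyond the shape checks in this test, a hypothetical follow-up (not part of this PR) would be to apply the extracted `grad_params` back to the module, for example as a manual SGD step; the `apply_sgd_step` helper, the 0.01 learning rate, and the in-place update below are illustrative assumptions.

# Hypothetical sketch: consume the "grad_params" dict produced by TorchParamBackwardEstimator.
# The learning rate, device/dtype moves, and in-place update are illustrative assumptions.
import torch


def apply_sgd_step(
    module: torch.nn.Module, grad_params: dict[str, torch.Tensor], lr: float = 0.01
) -> None:
    with torch.no_grad():
        for name, param in module.named_parameters():
            grad = grad_params[name].to(device=param.device, dtype=param.dtype)
            param -= lr * grad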