[Feature] Torch Module Parameters Gradient Estimator (#31)
Co-authored-by: James Braza <jamesbraza@gmail.com>
albertbou92 and jamesbraza authored Sep 18, 2024
1 parent 403adbc commit 911e685
Showing 3 changed files with 121 additions and 6 deletions.
87 changes: 87 additions & 0 deletions ldp/graph/gradient_estimators.py
@@ -6,10 +6,12 @@
from functools import partial
from typing import Any

import torch
import tree

from ldp.graph.op_utils import CallID
from ldp.graph.ops import GradInType, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp

logger = logging.getLogger(__name__)

@@ -142,3 +144,88 @@ def assign_default_gradients(

    tree.assert_same_structure(input_grads, input_args_kwargs)
    return input_grads


class TorchParamBackwardEstimator:
    """Gradient estimator for `TorchOp` internal parameters.

    This estimator computes gradients with respect to the internal parameters of a
    `torch.nn.Module` by calling the estimator's `backward` method instead of the
    default `backward` method of `TorchOp`. Computed gradients are stored in the
    context of the operation under the key `"grad_params"`.

    Examples:
        >>> torch_module = torch.nn.Sequential(
        ...     torch.nn.Linear(4, 4),
        ...     torch.nn.Linear(4, 1),
        ... )
        >>> torch_op = TorchOp(torch_module)
        >>> estimator = TorchParamBackwardEstimator(torch_module)
        >>> result = await torch_op(torch.randn(4, requires_grad=True))
        >>> result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    Note:
        This estimator is only compatible with `TorchOp` operations.
    """

    def __init__(self, module: torch.nn.Module):
        self.params = dict(module.named_parameters())

    def backward(
        self,
        ctx: OpCtx,
        input_args: list[ResultOrValue],
        input_kwargs: dict[str, ResultOrValue],
        grad_output: tree.Structure,
        call_id: CallID,
    ) -> GradInType:
        if ctx.op_name != TorchOp.__name__:
            raise RuntimeError(
                f"Attempted to use TorchParamBackwardEstimator with non-TorchOp operation {ctx.op_name}."
            )

        tensor_args, tensor_kwargs = ctx.get(call_id, TorchOp.CTX_TENSOR_INPUT_KEY)
        n_pos_args = len(tensor_args)
        n_pos_kwargs = len(tensor_kwargs)
        output = ctx.get(call_id, "output").value

        if not isinstance(grad_output, torch.Tensor):
            grad_output = torch.tensor(
                grad_output, dtype=output.dtype, device=output.device
            )

        while grad_output.ndim < output.ndim:
            # Assume we can broadcast, so expand dims
            # e.g. if output.shape = (2, 1, 1) and grad_output is a scalar
            # then we want to expand to (1, 1, 1) and then broadcast
            grad_output = grad_output.unsqueeze(-1)

        if output.shape != grad_output.shape:
            raise RuntimeError(
                f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
            )

        gradients = torch.autograd.grad(
            output,
            [*tensor_args, *tensor_kwargs.values(), *self.params.values()],
            grad_outputs=grad_output,
        )

        grad_args = [grad.detach().cpu().float() for grad in gradients[:n_pos_args]]
        grad_kwargs = {
            k: grad.detach().cpu().float()
            for k, grad in zip(
                tensor_kwargs.keys(),
                gradients[n_pos_args : n_pos_args + n_pos_kwargs],
                strict=True,
            )
        }
        grad_params = {
            name: grad.detach().cpu().float()
            for name, grad in zip(
                self.params.keys(), gradients[n_pos_args + n_pos_kwargs :], strict=True
            )
        }

        ctx.update(call_id=call_id, key="grad_params", value=grad_params)

        return grad_args, grad_kwargs
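
For context, a minimal end-to-end sketch of how the new estimator can be used, pieced together from the class docstring and the test added below. The manual SGD step at the end is illustrative only (a hypothetical consumer of the stored gradients), not part of this commit or of ldp's API.

import asyncio

import torch

from ldp.graph.gradient_estimators import TorchParamBackwardEstimator
from ldp.graph.torch_ops import TorchOp


async def main() -> None:
    module = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
    torch_op = TorchOp(module)
    estimator = TorchParamBackwardEstimator(module)

    # Forward pass through the op, then backward using the estimator's backward
    result = await torch_op(torch.randn(4, requires_grad=True))
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Parameter gradients are stored in the op context under "grad_params"
    call_id = next(iter(torch_op.get_call_ids({result.call_id.run_id})))
    grad_params = torch_op.ctx.get(call_id, "grad_params")

    # Illustrative only: apply a plain SGD update with the stored gradients
    with torch.no_grad():
        for name, param in module.named_parameters():
            param -= 0.01 * grad_params[name].to(param.device)


asyncio.run(main())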
12 changes: 6 additions & 6 deletions ldp/graph/torch_ops.py
@@ -83,6 +83,11 @@ def backward(
            # then we want to expand to (1, 1, 1) and then broadcast
            grad_output = grad_output.unsqueeze(-1)

+        if output.shape != grad_output.shape:
+            raise RuntimeError(
+                f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
+            )
+
        gradients = torch.autograd.grad(
            output,
            [*tensor_args, *tensor_kwargs.values()],
@@ -91,14 +96,9 @@
            retain_graph=True,
        )

-        grad_args = [
-            grad.detach().cpu().float() if grad is not None else None  # type: ignore[redundant-expr]
-            for grad in gradients[:n_pos_args]
-        ]
+        grad_args = [grad.detach().cpu().float() for grad in gradients[:n_pos_args]]
        grad_kwargs = {
            k: grad.detach().cpu().float()
-            if grad is not None  # type: ignore[redundant-expr]
-            else None
            for k, grad in zip(
                tensor_kwargs.keys(), gradients[n_pos_args:], strict=True
            )
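
To make the shape check added to TorchOp.backward concrete, here is a small standalone sketch (plain PyTorch, no ldp imports) of the same expand-then-compare logic; the expand_and_check helper is hypothetical, written only for illustration.

import torch


def expand_and_check(output: torch.Tensor, grad_output) -> torch.Tensor:
    # Mirror of the logic above: promote non-tensors, unsqueeze trailing dims,
    # then require an exact shape match before torch.autograd.grad is called.
    if not isinstance(grad_output, torch.Tensor):
        grad_output = torch.tensor(grad_output, dtype=output.dtype, device=output.device)
    while grad_output.ndim < output.ndim:
        grad_output = grad_output.unsqueeze(-1)
    if output.shape != grad_output.shape:
        raise RuntimeError(
            f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
        )
    return grad_output


scalar_out = torch.randn(())
print(expand_and_check(scalar_out, 1.0).shape)  # torch.Size([]) -- passes

vector_out = torch.randn(2, 3)
expand_and_check(vector_out, 1.0)  # raises: grad_output expands to (1, 1), not (2, 3)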
28 changes: 28 additions & 0 deletions tests/test_gradients.py
@@ -7,15 +7,18 @@

import numpy as np
import pytest
import torch
import tree

from ldp.graph.common_ops import ConfigOp, FxnOp
from ldp.graph.gradient_estimators import (
    TorchParamBackwardEstimator,
    assign_constant_grads,
    assign_default_grads,
)
from ldp.graph.op_utils import CallID, compute_graph
from ldp.graph.ops import GradInType, Op, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp


class PoissonSamplerOp(Op):
@@ -377,3 +380,28 @@ async def test_serial_ops_diff_run_id():

    with pytest.raises(RuntimeError, match="args and kwargs must have the same run_id"):
        await op2(result1)


@pytest.mark.parametrize("hidden_nodes", [1, 4])
@pytest.mark.asyncio
async def test_torch_param_backward_estimator(hidden_nodes: int):
    torch_module = torch.nn.Sequential(
        torch.nn.Linear(4, hidden_nodes),
        torch.nn.Linear(hidden_nodes, 1),
    )
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)

    # Forward pass
    result = await torch_op(torch.randn(4, requires_grad=True))

    # Backward pass
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Check that the gradients are computed and have the correct shape
    call_ids = torch_op.get_call_ids({result.call_id.run_id})
    grad_params = torch_op.ctx.get(next(iter(call_ids)), "grad_params")

    for named_param, grad_param in torch_module.named_parameters():
        assert named_param in grad_params
        assert grad_param.shape == grad_params[named_param].shape
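
Assuming pytest-asyncio is installed (the test is marked with pytest.mark.asyncio), this test can be run on its own with something like: python -m pytest tests/test_gradients.py -k torch_param_backward_estimator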
