diff --git a/ldp/graph/gradient_estimators.py b/ldp/graph/gradient_estimators.py
index 6041a36f..3e44affa 100644
--- a/ldp/graph/gradient_estimators.py
+++ b/ldp/graph/gradient_estimators.py
@@ -6,10 +6,12 @@
 from functools import partial
 from typing import Any
 
+import torch
 import tree
 
 from ldp.graph.op_utils import CallID
 from ldp.graph.ops import GradInType, OpCtx, OpResult, ResultOrValue
+from ldp.graph.torch_ops import TorchOp
 
 logger = logging.getLogger(__name__)
 
@@ -142,3 +144,88 @@ def assign_default_gradients(
 
     tree.assert_same_structure(input_grads, input_args_kwargs)
     return input_grads
+
+
+class TorchParamBackwardEstimator:
+    """
+    Gradient estimator for `TorchOp` internal parameters.
+
+    This estimator computes gradients with respect to the internal parameters of a
+    `torch.nn.Module` when its `backward` method is registered in place of the default
+    `TorchOp.backward`. Computed gradients are stored in the op call context under the
+    key `"grad_params"`.
+
+    Examples:
+        >>> torch_module = torch.nn.Sequential(
+        ...     torch.nn.Linear(4, 4),
+        ...     torch.nn.Linear(4, 1),
+        ... )
+        >>> torch_op = TorchOp(torch_module)
+        >>> estimator = TorchParamBackwardEstimator(torch_module)
+        >>> result = await torch_op(torch.randn(4, requires_grad=True))
+        >>> result.compute_grads(backward_fns={"TorchOp": estimator.backward})
+
+    Note:
+        This estimator is only compatible with `TorchOp` operations.
+    """
+
+    def __init__(self, module: torch.nn.Module):
+        self.params = dict(module.named_parameters())
+
+    def backward(
+        self,
+        ctx: OpCtx,
+        input_args: list[ResultOrValue],
+        input_kwargs: dict[str, ResultOrValue],
+        grad_output: tree.Structure,
+        call_id: CallID,
+    ) -> GradInType:
+        if ctx.op_name != TorchOp.__name__:
+            raise RuntimeError(
+                f"Attempted to use TorchParamBackwardEstimator with non-TorchOp operation {ctx.op_name}."
+            )
+
+        tensor_args, tensor_kwargs = ctx.get(call_id, TorchOp.CTX_TENSOR_INPUT_KEY)
+        n_pos_args = len(tensor_args)
+        n_pos_kwargs = n_pos_args + len(tensor_kwargs)
+        output = ctx.get(call_id, "output").value
+
+        if not isinstance(grad_output, torch.Tensor):
+            grad_output = torch.tensor(
+                grad_output, dtype=output.dtype, device=output.device
+            )
+
+        while grad_output.ndim < output.ndim:
+            # Assume we can broadcast, so expand dims
+            # e.g. if output.shape = (2, 1, 1) and grad_output is a scalar
+            # then we want to expand to (1, 1, 1) and then broadcast
+            grad_output = grad_output.unsqueeze(-1)
+
+        if output.shape != grad_output.shape:
+            raise RuntimeError(
+                f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
+            )
+
+        gradients = torch.autograd.grad(
+            output,
+            [*tensor_args, *tensor_kwargs.values(), *self.params.values()],
+            grad_outputs=grad_output,
+        )
+
+        grad_args = [grad.detach().cpu().float() for grad in gradients[:n_pos_args]]
+        grad_kwargs = {
+            k: grad.detach().cpu().float()
+            for k, grad in zip(
+                tensor_kwargs.keys(), gradients[n_pos_args:n_pos_kwargs], strict=True
+            )
+        }
+        grad_params = {
+            name: grad.detach().cpu().float()
+            for name, grad in zip(
+                self.params.keys(), gradients[n_pos_kwargs:], strict=True
+            )
+        }
+
+        ctx.update(call_id=call_id, key="grad_params", value=grad_params)
+
+        return grad_args, grad_kwargs
diff --git a/ldp/graph/torch_ops.py b/ldp/graph/torch_ops.py
index 07cd3ed1..c8fb2e15 100644
--- a/ldp/graph/torch_ops.py
+++ b/ldp/graph/torch_ops.py
@@ -83,6 +83,11 @@ def backward(
             # then we want to expand to (1, 1, 1) and then broadcast
             grad_output = grad_output.unsqueeze(-1)
 
+        if output.shape != grad_output.shape:
+            raise RuntimeError(
+                f"Output shape {output.shape} does not match grad_output shape {grad_output.shape}"
+            )
+
         gradients = torch.autograd.grad(
             output,
             [*tensor_args, *tensor_kwargs.values()],
@@ -91,14 +96,9 @@
             retain_graph=True,
         )
 
-        grad_args = [
-            grad.detach().cpu().float() if grad is not None else None  # type: ignore[redundant-expr]
-            for grad in gradients[:n_pos_args]
-        ]
+        grad_args = [grad.detach().cpu().float() for grad in gradients[:n_pos_args]]
         grad_kwargs = {
             k: grad.detach().cpu().float()
-            if grad is not None  # type: ignore[redundant-expr]
-            else None
             for k, grad in zip(
                 tensor_kwargs.keys(), gradients[n_pos_args:], strict=True
             )
diff --git a/tests/test_gradients.py b/tests/test_gradients.py
index 51e93cf7..272bc1a7 100644
--- a/tests/test_gradients.py
+++ b/tests/test_gradients.py
@@ -7,15 +7,18 @@
 
 import numpy as np
 import pytest
+import torch
 import tree
 
 from ldp.graph.common_ops import ConfigOp, FxnOp
 from ldp.graph.gradient_estimators import (
+    TorchParamBackwardEstimator,
     assign_constant_grads,
     assign_default_grads,
 )
 from ldp.graph.op_utils import CallID, compute_graph
 from ldp.graph.ops import GradInType, Op, OpCtx, OpResult, ResultOrValue
+from ldp.graph.torch_ops import TorchOp
 
 
 class PoissonSamplerOp(Op):
@@ -377,3 +380,28 @@ async def test_serial_ops_diff_run_id():
 
     with pytest.raises(RuntimeError, match="args and kwargs must have the same run_id"):
         await op2(result1)
+
+
+@pytest.mark.parametrize("hidden_nodes", [1, 4])
+@pytest.mark.asyncio
+async def test_torch_param_backward_estimator(hidden_nodes: int):
+    torch_module = torch.nn.Sequential(
+        torch.nn.Linear(4, hidden_nodes),
+        torch.nn.Linear(hidden_nodes, 1),
+    )
+    torch_op = TorchOp(torch_module)
+    estimator = TorchParamBackwardEstimator(torch_module)
+
+    # Forward pass
+    result = await torch_op(torch.randn(4, requires_grad=True))
+
+    # Backward pass
+    result.compute_grads(backward_fns={"TorchOp": estimator.backward})
+
+    # Check that the gradients are computed and have the correct shape
+    call_ids = torch_op.get_call_ids({result.call_id.run_id})
+    grad_params = torch_op.ctx.get(next(iter(call_ids)), "grad_params")
+
+    for named_param, grad_param in torch_module.named_parameters():
+        assert named_param in grad_params
+        assert grad_param.shape == grad_params[named_param].shape
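
Usage note (not part of the patch): the sketch below shows one way to consume the gradients that `TorchParamBackwardEstimator` stores under `"grad_params"`, by applying a plain SGD update to the wrapped module. It only uses calls exercised in the docstring and test above; the `manual_sgd_step` helper and the `lr` value are illustrative, not part of the ldp API.

    # Minimal sketch: manual SGD step driven by the "grad_params" gradients.
    # `manual_sgd_step` and `lr` are illustrative names, not ldp API.
    import asyncio

    import torch

    from ldp.graph.gradient_estimators import TorchParamBackwardEstimator
    from ldp.graph.torch_ops import TorchOp


    async def manual_sgd_step(lr: float = 1e-2) -> None:
        module = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
        torch_op = TorchOp(module)
        estimator = TorchParamBackwardEstimator(module)

        # Forward pass through the op, then backward with the estimator's
        # backward registered in place of TorchOp's default backward.
        result = await torch_op(torch.randn(4, requires_grad=True))
        result.compute_grads(backward_fns={"TorchOp": estimator.backward})

        # Retrieve the parameter gradients stashed in the op's call context.
        call_id = next(iter(torch_op.get_call_ids({result.call_id.run_id})))
        grad_params = torch_op.ctx.get(call_id, "grad_params")

        # Apply a plain SGD update using the stored gradients.
        with torch.no_grad():
            for name, param in module.named_parameters():
                param -= lr * grad_params[name].to(param.device, param.dtype)


    asyncio.run(manual_sgd_step())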