
[Feature] Torch Module Parameters Gradient Estimator #31

Merged
merged 11 commits into from Sep 18, 2024
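At a high level, this PR adds a `TorchParamBackwardEstimator` that, in addition to the usual input gradients, computes gradients with respect to the wrapped `torch.nn.Module`'s parameters and stores them in the op context under the `"grads_params"` key. A minimal usage sketch, mirroring the test added in `tests/test_gradients.py` below (the module shapes and input are illustrative):

```python
import asyncio

import torch

from ldp.graph.gradient_estimators import TorchParamBackwardEstimator
from ldp.graph.torch_ops import TorchOp


async def main() -> None:
    # Wrap a torch.nn.Module in a TorchOp and pair it with the estimator.
    module = torch.nn.Linear(4, 2)
    op = TorchOp(module)
    estimator = TorchParamBackwardEstimator(module)

    # Forward pass, then backward with the estimator registered for TorchOp nodes.
    result = await op(torch.randn(4, requires_grad=True))
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Parameter gradients land in the op context under "grads_params",
    # keyed by the names from module.named_parameters().
    call_id = next(iter(op.get_call_ids({result.call_id.run_id})))
    grads_params = op.ctx.get(call_id, "grads_params")
    print({name: grad.shape for name, grad in grads_params.items()})


asyncio.run(main())
```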
Changes from 5 commits
64 changes: 64 additions & 0 deletions ldp/graph/gradient_estimators.py
@@ -6,10 +6,12 @@
from functools import partial
from typing import Any

import torch
import tree

from ldp.graph.op_utils import CallID
from ldp.graph.ops import GradInType, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp

logger = logging.getLogger(__name__)

@@ -142,3 +144,65 @@ def assign_default_gradients(

    tree.assert_same_structure(input_grads, input_args_kwargs)
    return input_grads


class TorchParamBackwardEstimator:
    """Gradient estimator for TorchOp operations that also computes parameter gradients.

    In addition to the gradients w.r.t. the op's inputs, gradients w.r.t. the wrapped
    module's parameters are computed and stored in the op context under "grads_params".
    """

    def __init__(self, module: torch.nn.Module):
        self.module = module
        self.params = dict(module.named_parameters())

    def backward(
        self,
        ctx: OpCtx,
        input_args: list[ResultOrValue],
        input_kwargs: dict[str, ResultOrValue],
        grad_output: tree.Structure,
        call_id: CallID,
    ) -> GradInType:
        if ctx.op_name != "TorchOp":
            raise ValueError(
                f"Attempted to use TorchParamBackwardEstimator with non-TorchOp operation {ctx.op_name}"
            )

        tensor_args, tensor_kwargs = ctx.get(call_id, TorchOp.CTX_TENSOR_INPUT_KEY)
        n_pos_args = len(tensor_args)
        # Offset in the flattened gradients list at which parameter gradients start
        n_pos_kwargs = n_pos_args + len(tensor_kwargs)
        output = ctx.get(call_id, "output").value

        if not isinstance(grad_output, torch.Tensor):
            grad_output = torch.tensor(
                grad_output, dtype=output.dtype, device=output.device
            )

        if output.shape != grad_output.shape:
            # Should only occur if end of graph is not a scalar, where a sentinel [0] is used
            grad_output = torch.zeros_like(output)

        gradients = torch.autograd.grad(
            output,
            [*tensor_args, *tensor_kwargs.values(), *self.params.values()],
            grad_outputs=grad_output,
        )

        grad_args = [
            grad.detach().cpu().float() if grad is not None else None  # type: ignore[redundant-expr]
            for grad in gradients[:n_pos_args]
        ]
        grad_kwargs = {
            k: grad.detach().cpu().float()
            if grad is not None  # type: ignore[redundant-expr]
            else None
            for k, grad in zip(
                tensor_kwargs.keys(), gradients[n_pos_args:n_pos_kwargs], strict=True
            )
        }
        grads_params = {
            name: grad.detach().cpu().float()
            for name, grad in zip(
                self.params.keys(), gradients[n_pos_kwargs:], strict=True
            )
        }

        ctx.update(call_id=call_id, key="grads_params", value=grads_params)

        return grad_args, grad_kwargs
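Because the parameter gradients are stored in the op context as detached CPU float tensors keyed by `named_parameters()` names, a caller can read them back after `compute_grads()` and apply them however it likes. A minimal sketch of one possible consumer, using a hypothetical `apply_param_grads` helper that performs a plain SGD-style update (this helper is illustrative and not part of the PR):

```python
import torch

from ldp.graph.op_utils import CallID
from ldp.graph.ops import OpCtx


def apply_param_grads(
    module: torch.nn.Module, ctx: OpCtx, call_id: CallID, lr: float = 1e-2
) -> None:
    """Hypothetical helper: apply the stored "grads_params" as a plain SGD step."""
    grads_params = ctx.get(call_id, "grads_params")
    with torch.no_grad():
        for name, param in module.named_parameters():
            # The estimator detached the gradients and moved them to CPU float32,
            # so cast back to the parameter's device and dtype before updating.
            grad = grads_params[name].to(device=param.device, dtype=param.dtype)
            param -= lr * grad
```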
8 changes: 3 additions & 5 deletions ldp/graph/torch_ops.py
@@ -77,11 +77,9 @@ def backward(
                grad_output, dtype=output.dtype, device=output.device
            )

-        while grad_output.ndim < output.ndim:
-            # Assume we can broadcast, so expand dims
-            # e.g. if output.shape = (2, 1, 1) and grad_output is a scalar
-            # then we want to expand to (1, 1, 1) and then broadcast
-            grad_output = grad_output.unsqueeze(-1)
+        if output.shape != grad_output.shape:
+            # Should only occur if end of graph is not a scalar, where a sentinel [0] is used
+            grad_output = torch.zeros_like(output)

        gradients = torch.autograd.grad(
            output,
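For context on the torch_ops.py change above: the old code unsqueezed a lower-rank `grad_output` until it could broadcast against `output`, while the new code treats any shape mismatch as the non-scalar-terminal case and substitutes a zero gradient. A small standalone sketch of the new behavior in plain PyTorch (the `Linear` module and the sentinel value are illustrative, taken only from the comment about `[0]`):

```python
import torch

module = torch.nn.Linear(4, 3)        # illustrative: produces a non-scalar output
x = torch.randn(4, requires_grad=True)
output = module(x)

grad_output = torch.tensor([0.0])     # sentinel [0] arriving from a non-scalar graph terminal

if output.shape != grad_output.shape:
    # Mirrors the new branch: fall back to a zero gradient instead of broadcasting.
    grad_output = torch.zeros_like(output)

(grad_x,) = torch.autograd.grad(output, [x], grad_outputs=grad_output)
assert torch.all(grad_x == 0)         # zero upstream gradient yields zero input gradient
```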
24 changes: 24 additions & 0 deletions tests/test_gradients.py
@@ -7,15 +7,18 @@

import numpy as np
import pytest
import torch
import tree

from ldp.graph.common_ops import ConfigOp, FxnOp
from ldp.graph.gradient_estimators import (
    TorchParamBackwardEstimator,
    assign_constant_grads,
    assign_default_grads,
)
from ldp.graph.op_utils import CallID, compute_graph
from ldp.graph.ops import GradInType, Op, OpCtx, OpResult, ResultOrValue
from ldp.graph.torch_ops import TorchOp


class PoissonSamplerOp(Op):
@@ -377,3 +380,24 @@ async def test_serial_ops_diff_run_id():

    with pytest.raises(RuntimeError, match="args and kwargs must have the same run_id"):
        await op2(result1)


@pytest.mark.parametrize("output_nodes", [4])
@pytest.mark.asyncio
async def test_torch_param_backward_estimator(output_nodes: int):
    torch_module = torch.nn.Linear(4, output_nodes)
    torch_op = TorchOp(torch_module)
    estimator = TorchParamBackwardEstimator(torch_module)

    # Forward pass
    result = await torch_op(torch.randn(4, requires_grad=True))

    # Backward pass
    result.compute_grads(backward_fns={"TorchOp": estimator.backward})

    # Check that the gradients are computed and have the correct shape
    call_ids = torch_op.get_call_ids({result.call_id.run_id})
    grad_params = torch_op.ctx.get(next(iter(call_ids)), "grads_params")
    for named_param, grad_param in torch_module.named_parameters():
        assert named_param in grad_params
        assert grad_param.shape == grad_params[named_param].shape