Rename A to aggregator in backward and mtl_backward (#203)
* Rename A to aggregator in backward, mtl_backward and their usages
* Add changelog entry
ValerianRey authored Dec 10, 2024
1 parent 69f56f6 commit a698f68
Showing 11 changed files with 90 additions and 82 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -23,6 +23,8 @@ changes that do not affect the user.

### Changed

- **BREAKING**: Changed the name of the parameter `A` to `aggregator` in `backward` and
`mtl_backward`.
- **BREAKING**: Changed the order of the parameters of `backward` and `mtl_backward` to make it
possible to have a default value for `inputs` and for `shared_params` and `tasks_params`,
respectively. Usages of `backward` and `mtl_backward` that rely on the order between arguments
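A hedged migration sketch for the rename described in the first entry above, adapted from the basic-usage example touched by this commit (the model, loss, and data setup are illustrative assumptions; the only point is the `aggregator=` keyword that replaces `A=`):

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

import torchjd
from torchjd.aggregation import UPGrad

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.1)

input = torch.randn(16, 10)  # Batch of 16 random input vectors of length 10
target1 = torch.randn(16)  # First batch of 16 targets
target2 = torch.randn(16)  # Second batch of 16 targets

output = model(input)
loss1 = loss_fn(output[:, 0], target1)
loss2 = loss_fn(output[:, 1], target2)

optimizer.zero_grad()
# Before this commit, the aggregator was passed as `A`:
#     torchjd.backward([loss1, loss2], A=UPGrad())
# After this commit, the parameter is named `aggregator`:
torchjd.backward([loss1, loss2], aggregator=UPGrad())
optimizer.step()
```

Because keyword arguments are used here, the parameter reordering mentioned in the second entry does not affect this call; purely positional usages may need their argument order updated.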
4 changes: 2 additions & 2 deletions README.md
@@ -78,7 +78,7 @@ params = [

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
A = UPGrad()
aggregator = UPGrad()

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
@@ -92,7 +92,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, A=A)
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
```

4 changes: 2 additions & 2 deletions docs/source/examples/basic_usage.rst
@@ -32,7 +32,7 @@ Define the aggregator that will be used to combine the Jacobian matrix:

.. code-block:: python
A = UPGrad()
aggregator = UPGrad()
In essence, :doc:`UPGrad <../docs/aggregation/upgrad>` projects each gradient onto the dual cone of
the rows of the Jacobian and averages the results. This ensures that locally, no loss will be
@@ -69,7 +69,7 @@ Perform the Jacobian descent backward pass:

.. code-block:: python
torchjd.backward([loss1, loss2], A)
torchjd.backward([loss1, loss2], aggregator)
This will populate the ``.grad`` field of each model parameter with the corresponding aggregated
Jacobian matrix.
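The UPGrad behaviour described earlier in this file (projecting each gradient onto the dual cone of the rows of the Jacobian, then averaging) can be stated compactly; the notation below, with J denoting the Jacobian whose rows g_1, ..., g_m are the per-loss gradients and {v : Jv >= 0} the dual cone of those rows, is introduced here as an illustration:

```latex
\mathcal{A}_{\mathrm{UPGrad}}(J)
  = \frac{1}{m} \sum_{i=1}^{m} \pi_J(g_i),
\qquad
\pi_J(g) = \operatorname*{arg\,min}_{v \,:\, Jv \ge 0} \lVert v - g \rVert_2 .
```

Averaging the projected gradients, rather than the raw ones, is what prevents any single loss from being locally worsened by the aggregated update.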
4 changes: 2 additions & 2 deletions docs/source/examples/iwrm.rst
@@ -92,13 +92,13 @@ each Jacobian matrix consists of one gradient per loss. In this example, we use
params = model.parameters()
optimizer = SGD(params, lr=0.1)
A = UPGrad()
aggregator = UPGrad()
for x, y in zip(X, Y):
y_hat = model(x)
losses = loss_fn(y_hat, y)
optimizer.zero_grad()
backward(losses, A)
backward(losses, aggregator)
optimizer.step()
Note that in both cases, we use the `torch.optim.SGD
2 changes: 1 addition & 1 deletion docs/source/examples/lightning_integration.rst
@@ -44,7 +44,7 @@ The following code example demonstrates a basic multi-task learning setup using
opt = self.optimizers()
opt.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad())
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
opt.step()
def configure_optimizers(self) -> OptimizerLRScheduler:
4 changes: 2 additions & 2 deletions docs/source/examples/mtl.rst
@@ -39,7 +39,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
A = UPGrad()
aggregator = UPGrad()
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
@@ -53,7 +53,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks.
loss2 = loss_fn(output2, target2)
optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, A=A)
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()
.. note::
9 changes: 5 additions & 4 deletions src/torchjd/autojac/backward.py
@@ -15,18 +15,19 @@

def backward(
tensors: Sequence[Tensor] | Tensor,
A: Aggregator,
aggregator: Aggregator,
inputs: Iterable[Tensor] | None = None,
retain_graph: bool = False,
parallel_chunk_size: int | None = None,
) -> None:
r"""
Computes the Jacobian of all values in ``tensors`` with respect to all ``inputs``. Computes its
aggregation by ``A`` and accumulates it in the ``.grad`` fields of the ``inputs``.
aggregation by the provided ``aggregator`` and accumulates it in the ``.grad`` fields of the
``inputs``.
:param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobian
matrices will have one row for each value of each of these tensors.
:param A: Aggregator used to reduce the Jacobian into a vector.
:param aggregator: Aggregator used to reduce the Jacobian into a vector.
:param inputs: The tensors with respect to which the Jacobian must be computed. These must have
their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors
that were used to compute the ``tensors`` parameter.
@@ -95,7 +96,7 @@ def backward(
jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph)

# Transform that aggregates the Jacobians.
aggregate = Aggregate(A, inputs)
aggregate = Aggregate(aggregator, inputs)

# Transform that accumulates the result in the .grad field of the inputs.
accumulate = Accumulate(inputs)
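A minimal sketch of the renamed `backward` signature, mirroring the unit tests further down in this commit (the values are illustrative; `inputs` may be omitted, in which case it defaults to the leaf tensors that were used to compute `tensors`):

```python
import torch

import torchjd
from torchjd.aggregation import UPGrad

p1 = torch.tensor([1.0, 2.0], requires_grad=True)
p2 = torch.tensor([3.0, 4.0], requires_grad=True)

y1 = torch.tensor([-1.0, 1.0]) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

# The Jacobian of [y1, y2] with respect to p1 and p2 has one row per value in
# `tensors`; UPGrad reduces it to a vector accumulated into p1.grad and p2.grad.
torchjd.backward([y1, y2], aggregator=UPGrad(), inputs=[p1, p2])

assert p1.grad is not None and p2.grad is not None
```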
6 changes: 3 additions & 3 deletions src/torchjd/autojac/mtl_backward.py
@@ -27,7 +27,7 @@
def mtl_backward(
losses: Sequence[Tensor],
features: Sequence[Tensor] | Tensor,
A: Aggregator,
aggregator: Aggregator,
tasks_params: Sequence[Iterable[Tensor]] | None = None,
shared_params: Iterable[Tensor] | None = None,
retain_graph: bool = False,
@@ -45,7 +45,7 @@ def mtl_backward(
:param losses: The task losses. The Jacobian matrix will have one row per loss.
:param features: The last shared representation used for all tasks, as given by the feature
extractor. Should be non-empty.
:param A: Aggregator used to reduce the Jacobian into a vector.
:param aggregator: Aggregator used to reduce the Jacobian into a vector.
:param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags
must be set to ``True``. If not provided, the parameters considered for each task will
default to the leaf tensors that are in the computation graph of its loss, but that were not
@@ -129,7 +129,7 @@ def mtl_backward(
jac = Jac(features, shared_params, parallel_chunk_size, retain_graph)

# Transform that aggregates the Jacobians.
aggregate = Aggregate(A, shared_params)
aggregate = Aggregate(aggregator, shared_params)

# Transform that accumulates the result in the .grad field of the shared parameters.
accumulate = Accumulate(shared_params)
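A hedged sketch of calling the renamed `mtl_backward` with explicit `shared_params` and `tasks_params` (both are optional and default to being inferred from the computation graph, per the docstring above). The two-head architecture and the top-level `from torchjd import mtl_backward` import are assumptions modeled on the multi-task example elsewhere in this commit:

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward  # assumed import path
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

input = torch.randn(16, 10)  # Batch of 16 random input vectors of length 10
task1_target = torch.randn(16, 1)  # Batch of 16 targets for the first task
task2_target = torch.randn(16, 1)  # Batch of 16 targets for the second task

features = shared_module(input)
loss1 = loss_fn(task1_module(features), task1_target)
loss2 = loss_fn(task2_module(features), task2_target)

optimizer.zero_grad()
mtl_backward(
    losses=[loss1, loss2],
    features=features,
    aggregator=aggregator,
    shared_params=shared_module.parameters(),
    tasks_params=[task1_module.parameters(), task2_module.parameters()],
)
optimizer.step()
```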
14 changes: 7 additions & 7 deletions tests/doc/test_rst.py
@@ -9,7 +9,7 @@ def test_basic_usage():
model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
optimizer = SGD(model.parameters(), lr=0.1)

A = UPGrad()
aggregator = UPGrad()
input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10
target1 = torch.randn(16) # First batch of 16 targets
target2 = torch.randn(16) # Second batch of 16 targets
@@ -20,7 +20,7 @@ def test_basic_usage():
loss2 = loss_fn(output[:, 1], target2)

optimizer.zero_grad()
torchjd.backward([loss1, loss2], A)
torchjd.backward([loss1, loss2], aggregator)
optimizer.step()


@@ -62,13 +62,13 @@ def test_iwrm_with_ssjd():

params = model.parameters()
optimizer = SGD(params, lr=0.1)
A = UPGrad()
aggregator = UPGrad()

for x, y in zip(X, Y):
y_hat = model(x)
losses = loss_fn(y_hat, y)
optimizer.zero_grad()
backward(losses, A)
backward(losses, aggregator)
optimizer.step()

test_erm_with_sgd()
@@ -94,7 +94,7 @@ def test_mtl():

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
A = UPGrad()
aggregator = UPGrad()

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
@@ -108,7 +108,7 @@ def test_mtl():
loss2 = loss_fn(output2, target2)

optimizer.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, A=A)
mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
optimizer.step()


@@ -150,7 +150,7 @@ def training_step(self, batch, batch_idx) -> None:

opt = self.optimizers()
opt.zero_grad()
mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad())
mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
opt.step()

def configure_optimizers(self) -> OptimizerLRScheduler:
48 changes: 24 additions & 24 deletions tests/unit/autojac/test_backward.py
@@ -10,8 +10,8 @@
from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad


@mark.parametrize("A", [Mean(), UPGrad(), MGDA(), Random()])
def test_backward_various_aggregators(A: Aggregator):
@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()])
def test_backward_various_aggregators(aggregator: Aggregator):
"""Tests that backward works for various aggregators."""

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
@@ -21,17 +21,17 @@ def test_backward_various_aggregators(A: Aggregator):
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

backward([y1, y2], A)
backward([y1, y2], aggregator)

for p in params:
assert (p.grad is not None) and (p.shape == p.grad.shape)


@mark.parametrize("A", [Mean(), UPGrad(), MGDA()])
@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)])
@mark.parametrize("manually_specify_inputs", [True, False])
def test_backward_value_is_correct(
A: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
):
"""
Tests that the .grad value filled by backward is correct in a simple example of matrix-vector
@@ -47,15 +47,15 @@ def test_backward_value_is_correct(
else:
inputs = None

backward([output], A, inputs=inputs)
backward([output], aggregator, inputs=inputs)

assert_close(input.grad, A(J))
assert_close(input.grad, aggregator(J))


def test_backward_empty_inputs():
"""Tests that backward does not fill the .grad values if no input is specified."""

A = Mean()
aggregator = Mean()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -64,7 +64,7 @@ def test_backward_empty_inputs():
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

backward([y1, y2], A, inputs=[])
backward([y1, y2], aggregator, inputs=[])

for p in params:
assert p.grad is None
@@ -76,15 +76,15 @@ def test_backward_partial_inputs():
specified as inputs.
"""

A = Mean()
aggregator = Mean()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)

y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

backward([y1, y2], A, inputs=[p1])
backward([y1, y2], aggregator, inputs=[p1])

assert (p1.grad is not None) and (p1.shape == p1.grad.shape)
assert p2.grad is None
@@ -93,13 +93,13 @@ def test_backward_partial_inputs():
def test_backward_empty_tensors():
"""Tests that backward raises an error when called with an empty list of tensors."""

A = UPGrad()
aggregator = UPGrad()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)

with raises(ValueError):
backward([], A, inputs=[p1, p2])
backward([], aggregator, inputs=[p1, p2])


def test_backward_multiple_tensors():
@@ -108,7 +108,7 @@ def test_backward_multiple_tensors():
containing all the values of the original tensors.
"""

A = UPGrad()
aggregator = UPGrad()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -117,13 +117,13 @@ def test_backward_multiple_tensors():
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

backward([y1, y2], A, retain_graph=True)
backward([y1, y2], aggregator, retain_graph=True)

param_to_grad = {p: p.grad for p in params}
for p in params:
p.grad = None

backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), A)
backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), aggregator)

for p in params:
assert (p.grad == param_to_grad[p]).all()
@@ -133,7 +133,7 @@ def test_backward_valid_chunk_size(chunk_size):
def test_backward_valid_chunk_size(chunk_size):
"""Tests that backward works for various valid values of parallel_chunk_size."""

A = UPGrad()
aggregator = UPGrad()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -142,7 +142,7 @@ def test_backward_valid_chunk_size(chunk_size):
y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum()
y2 = (p1**2).sum() + p2.norm()

backward([y1, y2], A, parallel_chunk_size=chunk_size, retain_graph=True)
backward([y1, y2], aggregator, parallel_chunk_size=chunk_size, retain_graph=True)

for p in params:
assert (p.grad is not None) and (p.shape == p.grad.shape)
@@ -152,7 +152,7 @@ def test_backward_non_positive_chunk_size(chunk_size: int):
def test_backward_non_positive_chunk_size(chunk_size: int):
"""Tests that backward raises an error when using invalid chunk sizes."""

A = UPGrad()
aggregator = UPGrad()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -161,7 +161,7 @@ def test_backward_non_positive_chunk_size(chunk_size: int):
y2 = (p1**2).sum() + p2.norm()

with raises(ValueError):
backward([y1, y2], A, parallel_chunk_size=chunk_size)
backward([y1, y2], aggregator, parallel_chunk_size=chunk_size)


@mark.parametrize(
@@ -174,7 +174,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation:
large enough to allow differentiation of all tensors at once.
"""

A = UPGrad()
aggregator = UPGrad()

p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE)
p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE)
@@ -183,7 +183,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation:
y2 = (p1**2).sum() + p2.norm()

with expectation:
backward([y1, y2], A, retain_graph=False, parallel_chunk_size=chunk_size)
backward([y1, y2], aggregator, retain_graph=False, parallel_chunk_size=chunk_size)


def test_backward_fails_with_input_retaining_grad():
@@ -198,7 +198,7 @@ def test_backward_fails_with_input_retaining_grad():
c = 3 * b

with raises(RuntimeError):
backward(tensors=c, A=UPGrad(), inputs=[b])
backward(tensors=c, aggregator=UPGrad(), inputs=[b])


def test_backward_fails_with_non_input_retaining_grad():
@@ -213,7 +213,7 @@ def test_backward_fails_with_non_input_retaining_grad():
c = 3 * b

# backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor
backward(tensors=c, A=UPGrad(), inputs=[a])
backward(tensors=c, aggregator=UPGrad(), inputs=[a])

with raises(RuntimeError):
# Using such a BatchedTensor should result in an error