diff --git a/CHANGELOG.md b/CHANGELOG.md index 07e4713c..ddadc672 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ changes that do not affect the user. ### Changed +- **BREAKING**: Changed the name of the parameter `A` to `aggregator` in `backward` and + `mtl_backward`. - **BREAKING**: Changed the order of the parameters of `backward` and `mtl_backward` to make it possible to have a default value for `inputs` and for `shared_params` and `tasks_params`, respectively. Usages of `backward` and `mtl_backward` that rely on the order between arguments diff --git a/README.md b/README.md index e2da6ce9..455b9e65 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ params = [ loss_fn = MSELoss() optimizer = SGD(params, lr=0.1) -A = UPGrad() +aggregator = UPGrad() inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task @@ -92,7 +92,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets): loss2 = loss_fn(output2, target2) optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, A=A) + mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() ``` diff --git a/docs/source/examples/basic_usage.rst b/docs/source/examples/basic_usage.rst index 8e87960d..91af7253 100644 --- a/docs/source/examples/basic_usage.rst +++ b/docs/source/examples/basic_usage.rst @@ -32,7 +32,7 @@ Define the aggregator that will be used to combine the Jacobian matrix: .. code-block:: python - A = UPGrad() + aggregator = UPGrad() In essence, :doc:`UPGrad <../docs/aggregation/upgrad>` projects each gradient onto the dual cone of the rows of the Jacobian and averages the results. This ensures that locally, no loss will be @@ -69,7 +69,7 @@ Perform the Jacobian descent backward pass: .. code-block:: python - torchjd.backward([loss1, loss2], A) + torchjd.backward([loss1, loss2], aggregator) This will populate the ``.grad`` field of each model parameter with the corresponding aggregated Jacobian matrix. diff --git a/docs/source/examples/iwrm.rst b/docs/source/examples/iwrm.rst index 983b7762..65947695 100644 --- a/docs/source/examples/iwrm.rst +++ b/docs/source/examples/iwrm.rst @@ -92,13 +92,13 @@ each Jacobian matrix consists of one gradient per loss. In this example, we use params = model.parameters() optimizer = SGD(params, lr=0.1) - A = UPGrad() + aggregator = UPGrad() for x, y in zip(X, Y): y_hat = model(x) losses = loss_fn(y_hat, y) optimizer.zero_grad() - backward(losses, A) + backward(losses, aggregator) optimizer.step() Note that in both cases, we use the `torch.optim.SGD diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index 41af3d2d..4bf25891 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -44,7 +44,7 @@ The following code example demonstrates a basic multi-task learning setup using opt = self.optimizers() opt.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad()) + mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) opt.step() def configure_optimizers(self) -> OptimizerLRScheduler: diff --git a/docs/source/examples/mtl.rst b/docs/source/examples/mtl.rst index 7b12be28..17a747f7 100644 --- a/docs/source/examples/mtl.rst +++ b/docs/source/examples/mtl.rst @@ -39,7 +39,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. loss_fn = MSELoss() optimizer = SGD(params, lr=0.1) - A = UPGrad() + aggregator = UPGrad() inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task @@ -53,7 +53,7 @@ vectors of dimension 10, and their corresponding scalar labels for both tasks. loss2 = loss_fn(output2, target2) optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, A=A) + mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() .. note:: diff --git a/src/torchjd/autojac/backward.py b/src/torchjd/autojac/backward.py index 628643a5..da62a789 100644 --- a/src/torchjd/autojac/backward.py +++ b/src/torchjd/autojac/backward.py @@ -15,18 +15,19 @@ def backward( tensors: Sequence[Tensor] | Tensor, - A: Aggregator, + aggregator: Aggregator, inputs: Iterable[Tensor] | None = None, retain_graph: bool = False, parallel_chunk_size: int | None = None, ) -> None: r""" Computes the Jacobian of all values in ``tensors`` with respect to all ``inputs``. Computes its - aggregation by ``A`` and accumulates it in the ``.grad`` fields of the ``inputs``. + aggregation by the provided ``aggregator`` and accumulates it in the ``.grad`` fields of the + ``inputs``. :param tensors: The tensor or tensors to differentiate. Should be non-empty. The Jacobian matrices will have one row for each value of each of these tensors. - :param A: Aggregator used to reduce the Jacobian into a vector. + :param aggregator: Aggregator used to reduce the Jacobian into a vector. :param inputs: The tensors with respect to which the Jacobian must be computed. These must have their ``requires_grad`` flag set to ``True``. If not provided, defaults to the leaf tensors that were used to compute the ``tensors`` parameter. @@ -95,7 +96,7 @@ def backward( jac = Jac(tensors, inputs, parallel_chunk_size, retain_graph) # Transform that aggregates the Jacobians. - aggregate = Aggregate(A, inputs) + aggregate = Aggregate(aggregator, inputs) # Transform that accumulates the result in the .grad field of the inputs. accumulate = Accumulate(inputs) diff --git a/src/torchjd/autojac/mtl_backward.py b/src/torchjd/autojac/mtl_backward.py index 4c995cbb..f75f1cc4 100644 --- a/src/torchjd/autojac/mtl_backward.py +++ b/src/torchjd/autojac/mtl_backward.py @@ -27,7 +27,7 @@ def mtl_backward( losses: Sequence[Tensor], features: Sequence[Tensor] | Tensor, - A: Aggregator, + aggregator: Aggregator, tasks_params: Sequence[Iterable[Tensor]] | None = None, shared_params: Iterable[Tensor] | None = None, retain_graph: bool = False, @@ -45,7 +45,7 @@ def mtl_backward( :param losses: The task losses. The Jacobian matrix will have one row per loss. :param features: The last shared representation used for all tasks, as given by the feature extractor. Should be non-empty. - :param A: Aggregator used to reduce the Jacobian into a vector. + :param aggregator: Aggregator used to reduce the Jacobian into a vector. :param tasks_params: The parameters of each task-specific head. Their ``requires_grad`` flags must be set to ``True``. If not provided, the parameters considered for each task will default to the leaf tensors that are in the computation graph of its loss, but that were not @@ -129,7 +129,7 @@ def mtl_backward( jac = Jac(features, shared_params, parallel_chunk_size, retain_graph) # Transform that aggregates the Jacobians. - aggregate = Aggregate(A, shared_params) + aggregate = Aggregate(aggregator, shared_params) # Transform that accumulates the result in the .grad field of the shared parameters. accumulate = Accumulate(shared_params) diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 389f8b46..fcb26da2 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -9,7 +9,7 @@ def test_basic_usage(): model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2)) optimizer = SGD(model.parameters(), lr=0.1) - A = UPGrad() + aggregator = UPGrad() input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10 target1 = torch.randn(16) # First batch of 16 targets target2 = torch.randn(16) # Second batch of 16 targets @@ -20,7 +20,7 @@ def test_basic_usage(): loss2 = loss_fn(output[:, 1], target2) optimizer.zero_grad() - torchjd.backward([loss1, loss2], A) + torchjd.backward([loss1, loss2], aggregator) optimizer.step() @@ -62,13 +62,13 @@ def test_iwrm_with_ssjd(): params = model.parameters() optimizer = SGD(params, lr=0.1) - A = UPGrad() + aggregator = UPGrad() for x, y in zip(X, Y): y_hat = model(x) losses = loss_fn(y_hat, y) optimizer.zero_grad() - backward(losses, A) + backward(losses, aggregator) optimizer.step() test_erm_with_sgd() @@ -94,7 +94,7 @@ def test_mtl(): loss_fn = MSELoss() optimizer = SGD(params, lr=0.1) - A = UPGrad() + aggregator = UPGrad() inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task @@ -108,7 +108,7 @@ def test_mtl(): loss2 = loss_fn(output2, target2) optimizer.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, A=A) + mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator) optimizer.step() @@ -150,7 +150,7 @@ def training_step(self, batch, batch_idx) -> None: opt = self.optimizers() opt.zero_grad() - mtl_backward(losses=[loss1, loss2], features=features, A=UPGrad()) + mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad()) opt.step() def configure_optimizers(self) -> OptimizerLRScheduler: diff --git a/tests/unit/autojac/test_backward.py b/tests/unit/autojac/test_backward.py index a5ad03f3..440a32a6 100644 --- a/tests/unit/autojac/test_backward.py +++ b/tests/unit/autojac/test_backward.py @@ -10,8 +10,8 @@ from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad -@mark.parametrize("A", [Mean(), UPGrad(), MGDA(), Random()]) -def test_backward_various_aggregators(A: Aggregator): +@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) +def test_backward_various_aggregators(aggregator: Aggregator): """Tests that backward works for various aggregators.""" p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) @@ -21,17 +21,17 @@ def test_backward_various_aggregators(A: Aggregator): y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum() y2 = (p1**2).sum() + p2.norm() - backward([y1, y2], A) + backward([y1, y2], aggregator) for p in params: assert (p.grad is not None) and (p.shape == p.grad.shape) -@mark.parametrize("A", [Mean(), UPGrad(), MGDA()]) +@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()]) @mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)]) @mark.parametrize("manually_specify_inputs", [True, False]) def test_backward_value_is_correct( - A: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool + aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool ): """ Tests that the .grad value filled by backward is correct in a simple example of matrix-vector @@ -47,15 +47,15 @@ def test_backward_value_is_correct( else: inputs = None - backward([output], A, inputs=inputs) + backward([output], aggregator, inputs=inputs) - assert_close(input.grad, A(J)) + assert_close(input.grad, aggregator(J)) def test_backward_empty_inputs(): """Tests that backward does not fill the .grad values if no input is specified.""" - A = Mean() + aggregator = Mean() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) @@ -64,7 +64,7 @@ def test_backward_empty_inputs(): y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum() y2 = (p1**2).sum() + p2.norm() - backward([y1, y2], A, inputs=[]) + backward([y1, y2], aggregator, inputs=[]) for p in params: assert p.grad is None @@ -76,7 +76,7 @@ def test_backward_partial_inputs(): specified as inputs. """ - A = Mean() + aggregator = Mean() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) @@ -84,7 +84,7 @@ def test_backward_partial_inputs(): y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum() y2 = (p1**2).sum() + p2.norm() - backward([y1, y2], A, inputs=[p1]) + backward([y1, y2], aggregator, inputs=[p1]) assert (p1.grad is not None) and (p1.shape == p1.grad.shape) assert p2.grad is None @@ -93,13 +93,13 @@ def test_backward_partial_inputs(): def test_backward_empty_tensors(): """Tests that backward raises an error when called with an empty list of tensors.""" - A = UPGrad() + aggregator = UPGrad() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) with raises(ValueError): - backward([], A, inputs=[p1, p2]) + backward([], aggregator, inputs=[p1, p2]) def test_backward_multiple_tensors(): @@ -108,7 +108,7 @@ def test_backward_multiple_tensors(): containing the all the values of the original tensors. """ - A = UPGrad() + aggregator = UPGrad() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) @@ -117,13 +117,13 @@ def test_backward_multiple_tensors(): y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum() y2 = (p1**2).sum() + p2.norm() - backward([y1, y2], A, retain_graph=True) + backward([y1, y2], aggregator, retain_graph=True) param_to_grad = {p: p.grad for p in params} for p in params: p.grad = None - backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), A) + backward(torch.cat([y1.reshape(-1), y2.reshape(-1)]), aggregator) for p in params: assert (p.grad == param_to_grad[p]).all() @@ -133,7 +133,7 @@ def test_backward_multiple_tensors(): def test_backward_valid_chunk_size(chunk_size): """Tests that backward works for various valid values of parallel_chunk_size.""" - A = UPGrad() + aggregator = UPGrad() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) @@ -142,7 +142,7 @@ def test_backward_valid_chunk_size(chunk_size): y1 = torch.tensor([-1.0, 1.0], device=DEVICE) @ p1 + p2.sum() y2 = (p1**2).sum() + p2.norm() - backward([y1, y2], A, parallel_chunk_size=chunk_size, retain_graph=True) + backward([y1, y2], aggregator, parallel_chunk_size=chunk_size, retain_graph=True) for p in params: assert (p.grad is not None) and (p.shape == p.grad.shape) @@ -152,7 +152,7 @@ def test_backward_valid_chunk_size(chunk_size): def test_backward_non_positive_chunk_size(chunk_size: int): """Tests that backward raises an error when using invalid chunk sizes.""" - A = UPGrad() + aggregator = UPGrad() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) @@ -161,7 +161,7 @@ def test_backward_non_positive_chunk_size(chunk_size: int): y2 = (p1**2).sum() + p2.norm() with raises(ValueError): - backward([y1, y2], A, parallel_chunk_size=chunk_size) + backward([y1, y2], aggregator, parallel_chunk_size=chunk_size) @mark.parametrize( @@ -174,7 +174,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation: large enough to allow differentiation of all tensors at once. """ - A = UPGrad() + aggregator = UPGrad() p1 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) p2 = torch.tensor([3.0, 4.0], requires_grad=True, device=DEVICE) @@ -183,7 +183,7 @@ def test_backward_no_retain_graph_small_chunk_size(chunk_size: int, expectation: y2 = (p1**2).sum() + p2.norm() with expectation: - backward([y1, y2], A, retain_graph=False, parallel_chunk_size=chunk_size) + backward([y1, y2], aggregator, retain_graph=False, parallel_chunk_size=chunk_size) def test_backward_fails_with_input_retaining_grad(): @@ -198,7 +198,7 @@ def test_backward_fails_with_input_retaining_grad(): c = 3 * b with raises(RuntimeError): - backward(tensors=c, A=UPGrad(), inputs=[b]) + backward(tensors=c, aggregator=UPGrad(), inputs=[b]) def test_backward_fails_with_non_input_retaining_grad(): @@ -213,7 +213,7 @@ def test_backward_fails_with_non_input_retaining_grad(): c = 3 * b # backward itself doesn't raise the error, but it fills b.grad with a BatchedTensor - backward(tensors=c, A=UPGrad(), inputs=[a]) + backward(tensors=c, aggregator=UPGrad(), inputs=[a]) with raises(RuntimeError): # Using such a BatchedTensor should result in an error diff --git a/tests/unit/autojac/test_mtl_backward.py b/tests/unit/autojac/test_mtl_backward.py index d1727ea5..9e00b627 100644 --- a/tests/unit/autojac/test_mtl_backward.py +++ b/tests/unit/autojac/test_mtl_backward.py @@ -10,8 +10,8 @@ from torchjd.aggregation import MGDA, Aggregator, Mean, Random, UPGrad -@mark.parametrize("A", [Mean(), UPGrad(), MGDA(), Random()]) -def test_mtl_backward_various_aggregators(A: Aggregator): +@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA(), Random()]) +def test_mtl_backward_various_aggregators(aggregator: Aggregator): """Tests that mtl_backward works for various aggregators.""" p0 = torch.tensor([1.0, 2.0], requires_grad=True, device=DEVICE) @@ -23,18 +23,18 @@ def test_mtl_backward_various_aggregators(A: Aggregator): y1 = r1 * p1[0] + r2 * p1[1] y2 = r1 * p2[0] + r2 * p2[1] - mtl_backward(losses=[y1, y2], features=[r1, r2], A=A) + mtl_backward(losses=[y1, y2], features=[r1, r2], aggregator=aggregator) for p in [p0, p1, p2]: assert (p.grad is not None) and (p.shape == p.grad.shape) -@mark.parametrize("A", [Mean(), UPGrad(), MGDA()]) +@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()]) @mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)]) @mark.parametrize("manually_specify_shared_params", [True, False]) @mark.parametrize("manually_specify_tasks_params", [True, False]) def test_mtl_backward_value_is_correct( - A: Aggregator, + aggregator: Aggregator, shape: tuple[int, int], manually_specify_shared_params: bool, manually_specify_tasks_params: bool, @@ -71,7 +71,7 @@ def test_mtl_backward_value_is_correct( mtl_backward( losses=[y1, y2, y3], features=r, - A=A, + aggregator=aggregator, tasks_params=tasks_params, shared_params=shared_params, ) @@ -81,7 +81,7 @@ def test_mtl_backward_value_is_correct( assert_close(p3.grad, r) expected_jacobian = torch.stack((p1, p2, p3)) @ J - expected_aggregation = A(expected_jacobian) + expected_aggregation = aggregator(expected_jacobian) assert_close(p0.grad, expected_aggregation) @@ -95,7 +95,7 @@ def test_mtl_backward_empty_tasks(): r2 = (p0**2).sum() + p0.norm() with raises(ValueError): - mtl_backward(losses=[], features=[r1, r2], A=UPGrad()) + mtl_backward(losses=[], features=[r1, r2], aggregator=UPGrad()) def test_mtl_backward_single_task(): @@ -108,7 +108,7 @@ def test_mtl_backward_single_task(): r2 = (p0**2).sum() + p0.norm() y1 = r1 * p1[0] + r2 * p1[1] - mtl_backward(losses=[y1], features=[r1, r2], A=UPGrad()) + mtl_backward(losses=[y1], features=[r1, r2], aggregator=UPGrad()) for p in [p0, p1]: assert (p.grad is not None) and (p.shape == p.grad.shape) @@ -133,7 +133,7 @@ def test_mtl_backward_incoherent_task_number(): mtl_backward( losses=[y1, y2], features=[r1, r2], - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[p1]], # Wrong shared_params=[p0], ) @@ -141,7 +141,7 @@ def test_mtl_backward_incoherent_task_number(): mtl_backward( losses=[y1], # Wrong features=[r1, r2], - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[p0], ) @@ -162,7 +162,7 @@ def test_mtl_backward_empty_params(): mtl_backward( losses=[y1, y2], features=[r1, r2], - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[], []], shared_params=[], ) @@ -186,7 +186,7 @@ def test_mtl_backward_multiple_params_per_task(): y1 = r1 * p1_a + (r2 * p1_b).sum() + (r1 * p1_c).sum() y2 = r1 * p2_a * (r2 * p2_b).sum() - mtl_backward(losses=[y1, y2], features=[r1, r2], A=UPGrad()) + mtl_backward(losses=[y1, y2], features=[r1, r2], aggregator=UPGrad()) for p in [p0, p1_a, p1_b, p1_c, p2_a, p2_b]: assert (p.grad is not None) and (p.shape == p.grad.shape) @@ -221,7 +221,7 @@ def test_mtl_backward_various_shared_params(shared_params_shapes: list[tuple[int mtl_backward( losses=[y1, y2], features=representations, - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[p1], [p2]], # Enforce differentiation w.r.t. params that haven't been used shared_params=shared_params, ) @@ -248,7 +248,7 @@ def test_mtl_backward_partial_params(): mtl_backward( losses=[y1, y2], features=[r1, r2], - A=Mean(), + aggregator=Mean(), tasks_params=[[p1], []], shared_params=[p0], ) @@ -271,7 +271,7 @@ def test_mtl_backward_empty_features(): y2 = r1 * p2[0] + r2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[], A=UPGrad()) + mtl_backward(losses=[y1, y2], features=[], aggregator=UPGrad()) @mark.parametrize( @@ -295,7 +295,7 @@ def test_mtl_backward_various_single_features(shape: tuple[int, ...]): y1 = (r * p1[0]).sum() + (r * p1[1]).sum() y2 = (r * p2[0]).sum() * (r * p2[1]).sum() - mtl_backward(losses=[y1, y2], features=r, A=UPGrad()) + mtl_backward(losses=[y1, y2], features=r, aggregator=UPGrad()) for p in [p0, p1, p2]: assert (p.grad is not None) and (p.shape == p.grad.shape) @@ -326,7 +326,7 @@ def test_mtl_backward_various_feature_lists(shapes: list[tuple[int]]): y1 = sum([(r * p).sum() for r, p in zip(representations, p1)]) y2 = (representations[0] * p2).sum() - mtl_backward(losses=[y1, y2], features=representations, A=UPGrad()) + mtl_backward(losses=[y1, y2], features=representations, aggregator=UPGrad()) for p in [p0, p1, p2]: assert (p.grad is not None) and (p.shape == p.grad.shape) @@ -345,7 +345,7 @@ def test_mtl_backward_non_scalar_loss(): y2 = r1 * p2[0] + r2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[r1, r2], A=UPGrad()) + mtl_backward(losses=[y1, y2], features=[r1, r2], aggregator=UPGrad()) @mark.parametrize("chunk_size", [None, 1, 2, 4]) @@ -364,7 +364,7 @@ def test_mtl_backward_valid_chunk_size(chunk_size): mtl_backward( losses=[y1, y2], features=[r1, r2], - A=UPGrad(), + aggregator=UPGrad(), retain_graph=True, parallel_chunk_size=chunk_size, ) @@ -387,7 +387,12 @@ def test_mtl_backward_non_positive_chunk_size(chunk_size: int): y2 = r1 * p2[0] + r2 * p2[1] with raises(ValueError): - mtl_backward(losses=[y1, y2], features=[r1, r2], A=UPGrad(), parallel_chunk_size=chunk_size) + mtl_backward( + losses=[y1, y2], + features=[r1, r2], + aggregator=UPGrad(), + parallel_chunk_size=chunk_size, + ) @mark.parametrize( @@ -415,7 +420,7 @@ def test_mtl_backward_no_retain_graph_small_chunk_size( mtl_backward( losses=[y1, y2], features=[r1, r2], - A=UPGrad(), + aggregator=UPGrad(), retain_graph=False, parallel_chunk_size=chunk_size, ) @@ -441,7 +446,7 @@ def test_mtl_backward_fails_with_shared_param_retaining_grad(): mtl_backward( losses=[y1, y2], features=[features], - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[a, p0], ) @@ -467,7 +472,7 @@ def test_mtl_backward_fails_with_shared_activation_retaining_grad(): mtl_backward( losses=[y1, y2], features=[features], - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[p1], [p2]], shared_params=[p0], ) @@ -489,15 +494,15 @@ def test_mtl_backward_task_params_have_some_overlap(): y1 = r * p1 * p12 y2 = r * p2 * p12 - A = UPGrad() - mtl_backward(losses=[y1, y2], features=[r], A=A, retain_graph=True) + aggregator = UPGrad() + mtl_backward(losses=[y1, y2], features=[r], aggregator=aggregator, retain_graph=True) assert_close(p2.grad, r * p12) assert_close(p1.grad, r * p12) assert_close(p12.grad, r * p1 + r * p2) J = torch.tensor([[-p1 * p12, p1 * p12], [-p2 * p12, p2 * p12]], device=DEVICE) - assert_close(p0.grad, A(J)) + assert_close(p0.grad, aggregator(J)) def test_mtl_backward_task_params_are_the_same(): @@ -510,13 +515,13 @@ def test_mtl_backward_task_params_are_the_same(): y1 = r * p1 y2 = r + p1 - A = UPGrad() - mtl_backward(losses=[y1, y2], features=[r], A=A, retain_graph=True) + aggregator = UPGrad() + mtl_backward(losses=[y1, y2], features=[r], aggregator=aggregator, retain_graph=True) assert_close(p1.grad, r + 1) J = torch.tensor([[-p1, p1], [-1.0, 1.0]], device=DEVICE) - assert_close(p0.grad, A(J)) + assert_close(p0.grad, aggregator(J)) def test_mtl_backward_task_params_are_subset_of_other_task_params(): @@ -533,14 +538,14 @@ def test_mtl_backward_task_params_are_subset_of_other_task_params(): y1 = r * p1 y2 = y1 * p2 - A = UPGrad() - mtl_backward(losses=[y1, y2], features=[r], A=A, retain_graph=True) + aggregator = UPGrad() + mtl_backward(losses=[y1, y2], features=[r], aggregator=aggregator, retain_graph=True) assert_close(p2.grad, y1) assert_close(p1.grad, p2 * r + r) J = torch.tensor([[-p1, p1], [-p1 * p2, p1 * p2]], device=DEVICE) - assert_close(p0.grad, A(J)) + assert_close(p0.grad, aggregator(J)) def test_mtl_backward_shared_params_overlap_with_tasks_params(): @@ -561,7 +566,7 @@ def test_mtl_backward_shared_params_overlap_with_tasks_params(): mtl_backward( losses=[y1, y2], features=[r], - A=UPGrad(), + aggregator=UPGrad(), tasks_params=[[p1], [p0, p2]], # Problem: p0 is also shared shared_params=[p0], retain_graph=True, @@ -586,6 +591,6 @@ def test_mtl_backward_default_shared_params_overlap_with_default_tasks_params(): mtl_backward( losses=[y1, y2], features=[r], - A=UPGrad(), + aggregator=UPGrad(), retain_graph=True, )