
v0.8.0


@ValerianRey ValerianRey released this 13 Nov 18:54
· 108 commits to main since this release
5deba67

🔥 Autogram: a new engine for Jacobian descent 🔥

After months of hard work, we're happy to release autogram: a new engine to compute the Gramian G = J @ J.T of the Jacobian J of the losses with respect to the model parameters.

Have you ever had memory issues while using TorchJD? Try out this new approach!

This Gramian is computed iteratively, keeping only parts of J in memory at a time, so it is much more memory-efficient than computing the full Jacobian and multiplying it by its transpose.
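To see why this saves memory, note that the Gramian decomposes as a sum over blocks of columns of J (e.g. one block per group of parameters), so each block can be processed and discarded. Here is a minimal NumPy sketch of that identity; it is an illustration of the idea, not TorchJD's actual implementation, and the shapes and chunk size are arbitrary choices for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
J = rng.standard_normal((8, 1000))  # 8 losses, 1000 parameters (toy sizes)

# Full-Jacobian approach: materialize all of J, then multiply.
G_full = J @ J.T

# Chunked approach: accumulate the Gramian from column blocks of J,
# so only one block ever needs to be in memory at a time.
G_chunked = np.zeros((8, 8))
for start in range(0, J.shape[1], 100):
    block = J[:, start:start + 100]
    G_chunked += block @ block.T

assert np.allclose(G_full, G_chunked)
```

Since the Gramian is only m × m for m losses, the peak memory is driven by the largest block of J rather than by the whole Jacobian.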

Why does the Gramian of the Jacobian matter?

Most aggregators simply compute a weighted combination of the rows of the Jacobian, with weights that depend only on the Gramian of the Jacobian. So while in standard Jacobian descent you compute the Jacobian J and aggregate it into a vector to update the model, in Gramian-based Jacobian descent you directly compute the Gramian of the Jacobian, extract weights from this Gramian, and backpropagate the corresponding weighted combination of the losses.

This is equivalent to standard Jacobian descent, but much more memory-efficient, because the Jacobian never has to be fully stored in memory. It is thus also typically much faster, especially for instance-wise risk minimization (IWRM). For more theoretical justification, please read Section 6 of our paper.
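The equivalence rests on the chain rule: aggregating the rows of J with weights w gives the same vector as differentiating the weighted sum of the losses, wᵀ losses. Below is a small self-contained NumPy check of this on a toy per-instance squared-error model (the model, weights, and numeric-gradient check are all illustrative assumptions, not TorchJD code).

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 6))  # 4 instances, 6 parameters (toy sizes)
y = rng.standard_normal(4)
x = rng.standard_normal(6)

def losses(x):
    return (A @ x - y) ** 2  # one squared error per instance

# Analytic Jacobian of the losses with respect to x for this model.
J = 2 * (A @ x - y)[:, None] * A
G = J @ J.T  # the Gramian, from which a weighting scheme derives weights

w = rng.random(4)  # stand-in for weights extracted from G

# Standard Jacobian descent: aggregate the rows of J with weights w.
direct = J.T @ w

# Gramian-based Jacobian descent: differentiate the weighted loss sum
# instead (checked here with central finite differences).
eps = 1e-6
weighted = lambda x: w @ losses(x)
numeric = np.array([
    (weighted(x + eps * e) - weighted(x - eps * e)) / (2 * eps)
    for e in np.eye(6)
])

assert np.allclose(direct, numeric, atol=1e-4)
```

The two update directions coincide, which is why computing the Gramian first and then backpropagating a single weighted loss loses nothing relative to materializing the full Jacobian.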

How to make the switch?

Old engine (autojac):

from torchjd.autojac import backward
from torchjd.aggregation import UPGrad

aggregator = UPGrad()

# Repeat for several iterations
output = model(input)
losses = criterion(output, target)
optimizer.zero_grad()
backward(losses, aggregator=aggregator) 
optimizer.step()

New engine (autogram):

from torchjd.autogram import Engine
from torchjd.aggregation import UPGradWeighting

weighting = UPGradWeighting()
engine = Engine(model, batch_dim=0)  # Use batch_dim=None if not doing IWRM.

# Repeat for several iterations
output = model(input)
losses = criterion(output, target)
optimizer.zero_grad()
gramian = engine.compute_gramian(losses)
weights = weighting(gramian)
losses.backward(weights)
optimizer.step()

We're still working on making the engine even faster, but with this release you can already start using it. The interface is likely to change in the future, but adapting to these changes should always be easy!

Please open issues if you run into any problems while using it or if you have suggestions for improvements!

Changelog

Added

  • Added the autogram package, with the autogram.Engine. This is an implementation of Algorithm 3
    from Jacobian Descent for Multi-Objective Optimization,
    optimized for batched computations, as in IWRM. Generalized Gramians can also be obtained by using
    the autogram engine on a tensor of losses of arbitrary shape.
  • For all Aggregators based on weighting the Gramian of the Jacobian, made their
    Weighting class public. It can be used directly on a Gramian (computed via the
    autogram.Engine) to extract weights. The list of new public classes is:
    • Weighting (abstract base class)
    • UPGradWeighting
    • AlignedMTLWeighting
    • CAGradWeighting
    • ConstantWeighting
    • DualProjWeighting
    • IMTLGWeighting
    • KrumWeighting
    • MeanWeighting
    • MGDAWeighting
    • PCGradWeighting
    • RandomWeighting
    • SumWeighting
  • Added GeneralizedWeighting (base class) and Flattening (implementation) to extract tensors of
    weights from generalized Gramians.
  • Added usage example for IWRM with autogram.
  • Added usage example for IWRM with partial autogram.
  • Added usage example for IWMTL with autogram.
  • Added Python 3.14 classifier in pyproject.toml (we now also run tests on Python 3.14 in the CI).

Changed

  • Removed an unnecessary internal reshape when computing Jacobians. This should have no effect
    other than a slight performance improvement in autojac.
  • Revamped documentation.
  • Made backward and mtl_backward importable from torchjd.autojac (as was the case prior to 0.7.0).
  • Deprecated importing backward and mtl_backward from torchjd directly.