diff --git a/README.md b/README.md
index ca8df8ba..277d0940 100644
--- a/README.md
+++ b/README.md
@@ -56,24 +56,38 @@
 Some aggregators may have additional dependencies. Please refer to the
 [installation documentation](https://torchjd.org/stable/installation) for them.
 
 ## Usage
-There are two main ways to use TorchJD. The first one is to replace the usual call to
+
+Compared to standard `torch`, `torchjd` simply changes the way you obtain the `.grad` fields of
+your model parameters.
+
+### Using the `autojac` engine
+
+The `autojac` engine computes and aggregates Jacobians efficiently.
+
+#### 1. `backward` + `jac_to_grad`
+In standard `torch`, you generally combine your `losses` into a single scalar `loss`, and call
+`loss.backward()` to compute the gradient of the loss with respect to each model parameter and to
+store it in the `.grad` fields of those parameters. The basic usage of `torchjd` is to replace this
 `loss.backward()` by a call to
-[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
-[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
-on the use-case. This will compute the Jacobian of the vector of losses with respect to the model
-parameters, and aggregate it with the specified
-[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator).
-Whenever you want to optimize the vector of per-sample losses, you should rather use the
-[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/). Instead of
-computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
-memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
-this Gramian, using a
-[`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting),
-and used to combine the losses of the batch. Assuming each element of the batch is
-processed independently from the others, this approach is equivalent to
-[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while being
-generally much faster due to the lower memory usage. Note that we're still working on making
-`autogram` faster and more memory-efficient, and it's interface may change in future releases.
+[`torchjd.autojac.backward(losses)`](https://torchjd.org/stable/docs/autojac/backward/). Instead of
+computing the gradient of a scalar loss, it will compute the Jacobian of a vector of losses, and
+store it in the `.jac` fields of the model parameters. You then have to call
+[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) to aggregate
+this Jacobian using the specified
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Aggregator), and to
+store the result in the `.grad` fields of the model parameters. See this
+[usage example](https://torchjd.org/stable/examples/basic_usage/) for more details.
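+
+For illustration, a complete training step with `backward` and `jac_to_grad` could look roughly
+like the following sketch (the model, data and hyper-parameters are arbitrary placeholders):
+
+```python
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
+
+from torchjd.autojac import backward, jac_to_grad
+from torchjd.aggregation import UPGrad
+
+model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
+optimizer = SGD(model.parameters(), lr=0.1)
+aggregator = UPGrad()
+
+input = torch.randn(16, 10)  # A batch of 16 random input vectors of dimension 10
+target = torch.randn(16, 1)  # A batch of 16 random targets
+
+output = model(input)
+losses = MSELoss(reduction="none")(output, target).squeeze(1)  # One loss per element of the batch
+
+backward(losses)  # Computes the Jacobian of the losses and fills the .jac fields
+jac_to_grad(model.parameters(), aggregator)  # Aggregates .jac into .grad
+optimizer.step()
+optimizer.zero_grad()
+```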
+
+#### 2. `mtl_backward` + `jac_to_grad`
+In the case of multi-task learning, an alternative to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) is
+[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/). It computes
+the gradient of each task-specific loss with respect to the corresponding task's parameters, and
+stores it in their `.grad` fields. It also computes the Jacobian of the vector of losses with
+respect to the shared parameters and stores it in their `.jac` fields. Then, the
+[`torchjd.autojac.jac_to_grad`](https://torchjd.org/stable/docs/autojac/jac_to_grad/) function can
+be called to aggregate this Jacobian and replace the `.jac` fields with `.grad` fields for the
+shared parameters.
 
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
@@ -83,7 +97,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
   from torch.nn import Linear, MSELoss, ReLU, Sequential
   from torch.optim import SGD
-+ from torchjd.autojac import mtl_backward
++ from torchjd.autojac import jac_to_grad, mtl_backward
 + from torchjd.aggregation import UPGrad
 
   shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -112,7 +126,8 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 - loss = loss1 + loss2
 - loss.backward()
-+ mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
++ mtl_backward([loss1, loss2], features=features)
++ jac_to_grad(shared_module.parameters(), aggregator)
   optimizer.step()
   optimizer.zero_grad()
 ```
@@ -121,8 +136,42 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 > In this example, the Jacobian is only with respect to the shared parameters. The task-specific
 > parameters are simply updated via the gradient of their task’s loss with respect to them.
 
-The following example shows how to use TorchJD to minimize the vector of per-instance losses with
-Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
+> [!TIP]
+> Once your model parameters all have a `.grad` field, it's the role of the
+> [optimizer](https://docs.pytorch.org/docs/stable/optim.html#torch.optim.Optimizer) to update the
+> parameter values. This is exactly the same as in standard `torch`.
+
+#### 3. `jac`
+
+If you're simply interested in computing Jacobians without storing them in the `.jac` fields, you
+can also use the [`torchjd.autojac.jac`](https://torchjd.org/stable/docs/autojac/jac/) function,
+which is analogous to
+[`torch.autograd.grad`](https://docs.pytorch.org/docs/stable/generated/torch.autograd.grad.html),
+except that it computes the Jacobian of a vector of losses rather than the gradient of a scalar
+loss.
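+
+As a purely hypothetical illustration, assuming that `jac` mirrors the positional
+`(outputs, inputs)` interface of `torch.autograd.grad` (see the linked documentation for the exact
+signature), it could be used as follows:
+
+```python
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+
+from torchjd.autojac import jac
+
+model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
+
+input = torch.randn(16, 10)
+target = torch.randn(16, 1)
+losses = MSELoss(reduction="none")(model(input), target).squeeze(1)  # Vector of 16 losses
+
+# Hypothetical call, assumed to mirror torch.autograd.grad(outputs, inputs): it would return one
+# Jacobian per parameter, without touching the .jac or .grad fields.
+jacobians = jac(losses, model.parameters())
+```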
+
+### Using the `autogram` engine
+
+The Gramian of the Jacobian, defined as the Jacobian multiplied by its transpose, contains all the
+dot products between individual gradients. It thus contains all the information about conflict and
+gradient imbalance. It turns out that most aggregators from the literature
+(e.g. [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/)) compute a linear combination
+of the rows of the Jacobian, whose weights depend only on the Gramian of the Jacobian.
+
+An alternative implementation of Jacobian descent is thus to:
+- Compute this Gramian incrementally (layer by layer), without ever storing the full Jacobian in
+  memory.
+- Extract the weights from it using a
+  [`Weighting`](https://torchjd.org/stable/docs/aggregation#torchjd.aggregation.Weighting).
+- Combine the losses using those weights and take a step of gradient descent on the combined loss.
+
+The main advantage of this approach is that the Jacobian, which is typically large, never has to
+be stored in memory. The
+[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine/) is designed precisely
+to compute the Gramian of the Jacobian efficiently.
+
+The following example shows how to use the `autogram` engine to minimize the vector of per-instance
+losses with Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
 ```diff
   import torch
@@ -157,8 +206,8 @@ Jacobian descent using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgr
   optimizer.zero_grad()
 ```
 
-Lastly, you can even combine the two approaches by considering multiple tasks and each element of
-the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).
+You can even go one step further by considering multiple tasks and each element of the batch
+independently. We call that Instance-Wise Multitask Learning (IWMTL).
 
 ```python
 import torch
@@ -207,7 +256,7 @@ for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
 ```
 
 > [!NOTE]
-> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
+> Here, because the losses are a matrix instead of a simple vector, we compute a *generalized
 > Gramian* and we extract weights from it using a
 > [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/#torchjd.aggregation.GeneralizedWeighting).