
Commit fa5adb8

update docstring

1 parent f63d074 · commit fa5adb8
21 files changed: +3294 -147 lines changed

ConFIG/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 #usr/bin/python3
 # -*- coding: UTF-8 -*-
 import torch
-from typing import Optional,Sequence,Union
+from typing import Optional,Sequence,Union,Tuple

ConFIG/grad_operator.py

Lines changed: 118 additions & 41 deletions

@@ -10,19 +10,40 @@
 def ConFIG_update_double(grad_1:torch.Tensor,grad_2:torch.Tensor,
                          weight_model:WeightModel=EqualWeight(),
                          length_model:LengthModel=ProjectionLength(),
-                         losses:Optional[Sequence]=None):
+                         losses:Optional[Sequence]=None) -> torch.Tensor:
     """
     ConFIG update for two gradients where no inverse calculation is needed.
 
     Args:
         grad_1 (torch.Tensor): The first gradient.
         grad_2 (torch.Tensor): The second gradient.
-        weight_model (WeightModel, optional): The weight model to determine the coefficients. Defaults to EqualWeight().
-        length_model (LengthModel, optional): The length model to rescale the target vector. Defaults to ProjectionLength().
-        losses (Optional[Sequence], optional): The losses associated with the gradients. Defaults to None.
+        weight_model (WeightModel, optional): The weight model for calculating the direction weights.
+            Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any specific gradient.
+        length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
+            Defaults to ProjectionLength(), which projects each gradient vector onto the final gradient vector to obtain the final length.
+        losses (Optional[Sequence], optional): The losses associated with the gradients.
+            The losses are passed on to the weight and length models. If your weight/length model does not
+            require loss information, you can leave this as None. Defaults to None.
 
     Returns:
-        torch.Tensor: The rescaled length of the best direction.
+        torch.Tensor: The final update gradient.
+
+    Examples:
+        ```python
+        from ConFIG.grad_operator import ConFIG_update_double
+        from ConFIG.utils import get_gradient_vector,apply_gradient_vector
+        optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
+        for input_i in dataset:
+            grads=[] # we record gradients rather than losses
+            for loss_fn in [loss_fn1, loss_fn2]:
+                optimizer.zero_grad()
+                loss_i=loss_fn(input_i)
+                loss_i.backward()
+                grads.append(get_gradient_vector(network)) # get the loss-specific gradient
+            g_config=ConFIG_update_double(grads[0],grads[1]) # calculate the conflict-free direction
+            apply_gradient_vector(network,g_config) # apply the conflict-free direction to the network
+            optimizer.step()
+        ```
 
     """
     with torch.no_grad():
@@ -49,20 +70,42 @@ def ConFIG_update(
         weight_model:WeightModel=EqualWeight(),
         length_model:LengthModel=ProjectionLength(),
         use_latest_square:bool=True,
-        losses:Optional[Sequence]=None
-        ):
+        losses:Optional[Sequence]=None) -> torch.Tensor:
     """
     Performs the standard ConFIG update step.
 
     Args:
-        grads (Sequence[torch.Tensor]): The gradients to update.
-        weight_model (WeightModel, optional): The weight model to use for calculating weights. Defaults to EqualWeight().
-        length_model (LengthModel, optional): The length model to use for rescaling the length of the target vector. Defaults to ProjectionLength().
-        use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction. Defaults to True.
-        losses (Optional[Sequence], optional): The losses associated with the gradients. Defaults to None.
+        grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
+            It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+        weight_model (WeightModel, optional): The weight model for calculating the direction weights.
+            Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any specific gradient.
+        length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
+            Defaults to ProjectionLength(), which projects each gradient vector onto the final gradient vector to obtain the final length.
+        use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction.
+            If set to False, the pseudo-inverse of the gradient matrix is calculated directly. Setting it to True is recommended. Defaults to True.
+        losses (Optional[Sequence], optional): The losses associated with the gradients.
+            The losses are passed on to the weight and length models. If your weight/length model does not
+            require loss information, you can leave this as None. Defaults to None.
 
     Returns:
-        torch.Tensor: The rescaled length of the target vector.
+        torch.Tensor: The final update gradient.
+
+    Examples:
+        ```python
+        from ConFIG.grad_operator import ConFIG_update
+        from ConFIG.utils import get_gradient_vector,apply_gradient_vector
+        optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
+        for input_i in dataset:
+            grads=[] # we record gradients rather than losses
+            for loss_fn in loss_fns:
+                optimizer.zero_grad()
+                loss_i=loss_fn(input_i)
+                loss_i.backward()
+                grads.append(get_gradient_vector(network)) # get the loss-specific gradient
+            g_config=ConFIG_update(grads) # calculate the conflict-free direction
+            apply_gradient_vector(network,g_config) # apply the conflict-free direction to the network
+            optimizer.step()
+        ```
     """
     if not isinstance(grads,torch.Tensor):
         grads=torch.stack(grads)
@@ -80,7 +123,7 @@ def ConFIG_update(
 
 class GradientOperator:
     """
-    A class that represents a gradient operator.
+    A base class that represents a gradient operator.
 
     Methods:
         calculate_gradient: Calculates the gradient based on the given gradients and losses.
@@ -91,13 +134,16 @@ class GradientOperator:
     def __init__(self):
         pass
 
-    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
         """
         Calculates the gradient based on the given gradients and losses.
 
         Args:
-            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients.
-            losses (Optional[Sequence]): The losses (default: None).
+            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
+                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+            losses (Optional[Sequence], optional): The losses associated with the gradients.
+                The losses are passed on to the weight and length models. If your weight/length model does not
+                require loss information, you can leave this as None. Defaults to None.
 
         Returns:
             torch.Tensor: The calculated gradient.
@@ -108,14 +154,17 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],
         """
         raise NotImplementedError("calculate_gradient method must be implemented")
 
-    def update_gradient(self, network:torch.nn.Module, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+    def update_gradient(self, network: torch.nn.Module, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> None:
         """
-        Updates the gradient of the network based on the calculated gradient.
+        Calculates the gradient and applies it to the network.
 
         Args:
-            network: The network.
-            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients.
-            losses (Optional[Sequence]): The losses (default: None).
+            network (torch.nn.Module): The target network.
+            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
+                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+            losses (Optional[Sequence], optional): The losses associated with the gradients.
+                The losses are passed on to the weight and length models. If your weight/length model does not
+                require loss information, you can leave this as None. Defaults to None.
 
         Returns:
             None
@@ -126,13 +175,36 @@ def update_gradient(self, network:torch.nn.Module, grads: Union[torch.Tensor,Seq
 
 class ConFIGOperator(GradientOperator):
     """
-    ConFIGOperator class represents a gradient operator for ConFIG algorithm.
+    Operator for the ConFIG algorithm.
 
     Args:
-        weight_model (WeightModel, optional): The weight model to be used for calculating the gradient. Defaults to EqualWeight().
-        length_model (LengthModel, optional): The length model to be used for calculating the gradient. Defaults to ProjectionLength().
-        allow_simplified_model (bool, optional): Whether to allow simplified model for calculating the gradient. Defaults to True.
-        use_latest_square (bool, optional): Whether to use the latest square for calculating the gradient. Defaults to True.
+        weight_model (WeightModel, optional): The weight model for calculating the direction weights.
+            Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any specific gradient.
+        length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
+            Defaults to ProjectionLength(), which projects each gradient vector onto the final gradient vector to obtain the final length.
+        allow_simplified_model (bool, optional): Whether to allow the simplified model for calculating the gradient.
+            If set to True, the simplified form of the ConFIG method (ConFIG_update_double) is used when there are only two losses. Defaults to True.
+        use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction.
+            If set to False, the pseudo-inverse of the gradient matrix is calculated directly. Setting it to True is recommended. Defaults to True.
+
+    Examples:
+        ```python
+        from ConFIG.grad_operator import ConFIGOperator
+        from ConFIG.utils import get_gradient_vector,apply_gradient_vector
+        optimizer=torch.optim.Adam(network.parameters(),lr=1e-3)
+        operator=ConFIGOperator() # initialize the operator
+        for input_i in dataset:
+            grads=[]
+            for loss_fn in loss_fns:
+                optimizer.zero_grad()
+                loss_i=loss_fn(input_i)
+                loss_i.backward()
+                grads.append(get_gradient_vector(network))
+            g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
+            apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and apply the conflict-free direction in one call
+            optimizer.step()
+        ```
+
     """
 
     def __init__(self,
@@ -146,13 +218,16 @@ def __init__(self,
         self.allow_simplified_model = allow_simplified_model
         self.use_latest_square = use_latest_square
 
-    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
         """
         Calculates the gradient using the ConFIG algorithm.
 
         Args:
-            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
-            losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
+            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
+                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+            losses (Optional[Sequence], optional): The losses associated with the gradients.
+                The losses are passed on to the weight and length models. If your weight/length model does not
+                require loss information, you can leave this as None. Defaults to None.
 
         Returns:
            torch.Tensor: The calculated gradient.
@@ -186,16 +261,17 @@ class PCGradOperator(GradientOperator):
     """
 
 
-    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
         """
-        Calculates the gradient using the ConFIG algorithm.
+        Calculates the gradient using the PCGrad algorithm.
 
         Args:
-            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
-            losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
+            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
+                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+            losses (Optional[Sequence], optional): Not used by this operator. Defaults to None.
 
         Returns:
-            torch.Tensor: The calculated gradient.
+            torch.Tensor: The calculated gradient using the PCGrad method.
         """
         if not isinstance(grads,torch.Tensor):
             grads=torch.stack(grads)
@@ -212,7 +288,7 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],
 
 class IMTLGOperator(GradientOperator):
     """
-    PCGradOperator class represents a gradient operator for IMTLG algorithm.
+    IMTLGOperator class represents a gradient operator for the IMTL-G algorithm.
 
     @inproceedings{
     liu2021towards,
@@ -226,16 +302,17 @@ class IMTLGOperator(GradientOperator):
     """
 
 
-    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+    def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
         """
-        Calculates the gradient using the ConFIG algorithm.
+        Calculates the gradient using the IMTL-G algorithm.
 
         Args:
-            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
-            losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
+            grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to update.
+                It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+            losses (Optional[Sequence], optional): Not used by this operator. Defaults to None.
 
         Returns:
-            torch.Tensor: The calculated gradient.
+            torch.Tensor: The calculated gradient using the IMTL-G method.
         """
         if not isinstance(grads,torch.Tensor):
             grads=torch.stack(grads)
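The docstring examples in this file are deliberately abbreviated (no `import torch`, no network or dataset definitions). A self-contained sketch of the same training loop is shown below; the network, dataset, and both loss functions are hypothetical placeholders, while `ConFIG_update`, `get_gradient_vector`, and `apply_gradient_vector` are the library functions documented in this diff:

```python
import torch
from ConFIG.grad_operator import ConFIG_update
from ConFIG.utils import get_gradient_vector, apply_gradient_vector

# Hypothetical two-task setup; any network and set of losses works the same way.
network = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
dataset = [torch.randn(8, 4) for _ in range(10)]
loss_fns = [
    lambda x: network(x).pow(2).mean(),         # first loss term
    lambda x: (network(x) - 1.0).abs().mean(),  # second loss term
]

for input_i in dataset:
    grads = []  # one flattened gradient vector per loss
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_i = loss_fn(input_i)
        loss_i.backward()
        grads.append(get_gradient_vector(network))
    g_config = ConFIG_update(grads)           # conflict-free update direction
    apply_gradient_vector(network, g_config)  # write it back into the .grad buffers
    optimizer.step()
```

With exactly two losses, `ConFIG_update_double(grads[0], grads[1])` (or `ConFIGOperator` with `allow_simplified_model=True`) computes the same direction without the pseudo-inverse calculation.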

ConFIG/length_model.py

Lines changed: 33 additions & 16 deletions
@@ -4,7 +4,7 @@
 
 class LengthModel:
     """
-    This class represents a length model.
+    The base class for length models.
 
     Methods:
         get_length: Calculates the length based on the given parameters.
@@ -17,35 +17,36 @@ def get_length(self,
                    target_vector:Optional[torch.Tensor]=None,
                    unit_target_vector:Optional[torch.Tensor]=None,
                    gradients:Optional[torch.Tensor]=None,
-                   losses:Optional[Sequence]=None):
+                   losses:Optional[Sequence]=None) -> Union[torch.Tensor, float]:
         """
-        Calculates the length based on the given parameters.
+        Calculates the length based on the given parameters. Not all parameters are required.
 
         Args:
-            target_vector: The target vector.
-            unit_target_vector: The unit target vector.
-            gradients: The gradients.
-            losses: The losses.
+            target_vector (Optional[torch.Tensor]): The final update gradient vector.
+            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
+            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
+            losses (Optional[Sequence]): The losses.
 
         Returns:
-            The calculated length.
+            Union[torch.Tensor, float]: The calculated length.
         """
         raise NotImplementedError("This method must be implemented by the subclass.")
 
     def rescale_length(self,
                        target_vector:torch.Tensor,
                        gradients:Optional[torch.Tensor]=None,
-                       losses:Optional[Sequence]=None):
+                       losses:Optional[Sequence]=None) -> torch.Tensor:
         """
         Rescales the length of the target vector based on the given parameters.
-
+        It calls the get_length method to calculate the length and then rescales the target vector.
+
         Args:
-            target_vector: The target vector.
-            gradients: The gradients.
-            losses: The losses.
-
+            target_vector (torch.Tensor): The final update gradient vector.
+            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
+            losses (Optional[Sequence]): The losses.
+
         Returns:
-            The rescaled length.
+            torch.Tensor: The rescaled target vector.
         """
         unit_target_vector = unit_vector(target_vector)
         return self.get_length(target_vector=target_vector,
@@ -64,7 +65,23 @@ def __init__(self):
     def get_length(self, target_vector:Optional[torch.Tensor]=None,
                    unit_target_vector:Optional[torch.Tensor]=None,
                    gradients:Optional[torch.Tensor]=None,
-                   losses:Optional[Sequence]=None):
+                   losses:Optional[Sequence]=None) -> torch.Tensor:
+        """
+        Calculates the length based on the given parameters. Not all parameters are required.
+
+        Args:
+            target_vector (Optional[torch.Tensor]): The final update gradient vector.
+                Either `target_vector` or `unit_target_vector` needs to be provided.
+            unit_target_vector (Optional[torch.Tensor]): The unit vector of the target vector.
+                Either `target_vector` or `unit_target_vector` needs to be provided.
+            gradients (Optional[torch.Tensor]): The loss-specific gradients matrix.
+            losses (Optional[Sequence]): The losses. Not used in this model.
+
+        Returns:
+            torch.Tensor: The calculated length.
+        """
         if gradients is None:
             raise ValueError("The ProjectionLength model requires gradients information.")
+        if unit_target_vector is None:
+            unit_target_vector = unit_vector(target_vector)
         return torch.sum(torch.stack([torch.dot(grad_i,unit_target_vector) for grad_i in gradients]))
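As a sketch of the base-class contract above: a custom length model only needs to override `get_length`; the inherited `rescale_length` then rescales the unit target vector by the returned length, as its docstring describes. The `MeanNormLength` model below is a hypothetical example (not part of this commit) that sets the final length to the mean norm of the loss-specific gradients:

```python
import torch
from typing import Optional, Sequence, Union
from ConFIG.length_model import LengthModel

class MeanNormLength(LengthModel):
    """Hypothetical length model: final length = mean norm of the loss-specific gradients."""

    def get_length(self,
                   target_vector: Optional[torch.Tensor] = None,
                   unit_target_vector: Optional[torch.Tensor] = None,
                   gradients: Optional[torch.Tensor] = None,
                   losses: Optional[Sequence] = None) -> Union[torch.Tensor, float]:
        if gradients is None:
            raise ValueError("The MeanNormLength model requires gradients information.")
        return gradients.norm(dim=1).mean()  # norm of each stacked gradient row, averaged
```

Such a model could then be passed as `length_model=MeanNormLength()` to `ConFIG_update` or `ConFIGOperator` in place of the default `ProjectionLength()`.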
