def ConFIG_update_double(grad_1: torch.Tensor, grad_2: torch.Tensor,
                         weight_model: WeightModel = EqualWeight(),
                         length_model: LengthModel = ProjectionLength(),
-                        losses: Optional[Sequence] = None):
+                        losses: Optional[Sequence] = None) -> torch.Tensor:
    """
    ConFIG update for two gradients where no inverse calculation is needed.

    Args:
        grad_1 (torch.Tensor): The first gradient.
        grad_2 (torch.Tensor): The second gradient.
-       weight_model (WeightModel, optional): The weight model to determine the coefficients. Defaults to EqualWeight().
-       length_model (LengthModel, optional): The length model to rescale the target vector. Defaults to ProjectionLength().
-       losses (Optional[Sequence], optional): The losses associated with the gradients. Defaults to None.
+       weight_model (WeightModel, optional): The weight model for calculating the direction weights.
+           Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any single gradient.
+       length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
+           Defaults to ProjectionLength(), which projects each gradient onto the final gradient to determine the final length.
+       losses (Optional[Sequence], optional): The losses associated with the gradients.
+           The losses are passed to the weight and length models. If your weight/length model does not require
+           loss information, leave this as None. Defaults to None.

    Returns:
-       torch.Tensor: The rescaled length of the best direction.
+       torch.Tensor: The final update gradient.
+
+   Examples:
+       ```python
+       import torch
+       from ConFIG.grad_operator import ConFIG_update_double
+       from ConFIG.utils import get_gradient_vector, apply_gradient_vector
+       optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+       for input_i in dataset:
+           grads = []  # we record gradients rather than losses
+           for loss_fn in [loss_fn1, loss_fn2]:
+               optimizer.zero_grad()
+               loss_i = loss_fn(input_i)
+               loss_i.backward()
+               grads.append(get_gradient_vector(network))  # get the loss-specific gradient
+           g_config = ConFIG_update_double(grads[0], grads[1])  # calculate the conflict-free direction
+           apply_gradient_vector(network, g_config)  # set the conflict-free direction to the network
+           optimizer.step()
+       ```

    """
    with torch.no_grad():
@@ -49,20 +70,42 @@ def ConFIG_update(
                  weight_model: WeightModel = EqualWeight(),
                  length_model: LengthModel = ProjectionLength(),
                  use_latest_square: bool = True,
-                 losses: Optional[Sequence] = None
-                 ):
+                 losses: Optional[Sequence] = None) -> torch.Tensor:
    """
    Performs the standard ConFIG update step.

    Args:
-       grads (Sequence[torch.Tensor]): The gradients to update.
-       weight_model (WeightModel, optional): The weight model to use for calculating weights. Defaults to EqualWeight().
-       length_model (LengthModel, optional): The length model to use for rescaling the length of the target vector. Defaults to ProjectionLength().
-       use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction. Defaults to True.
-       losses (Optional[Sequence], optional): The losses associated with the gradients. Defaults to None.
+       grads (Union[torch.Tensor, Sequence[torch.Tensor]]): The gradients to update.
+           It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+       weight_model (WeightModel, optional): The weight model for calculating the direction weights.
+           Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any single gradient.
+       length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
+           Defaults to ProjectionLength(), which projects each gradient onto the final gradient to determine the final length.
+       use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction.
+           If False, the pseudo-inverse of the gradient matrix is calculated directly. Setting this to True is recommended. Defaults to True.
+       losses (Optional[Sequence], optional): The losses associated with the gradients.
+           The losses are passed to the weight and length models. If your weight/length model does not require
+           loss information, leave this as None. Defaults to None.

    Returns:
-       torch.Tensor: The rescaled length of the target vector.
+       torch.Tensor: The final update gradient.
+
+   Examples:
+       ```python
+       import torch
+       from ConFIG.grad_operator import ConFIG_update
+       from ConFIG.utils import get_gradient_vector, apply_gradient_vector
+       optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+       for input_i in dataset:
+           grads = []  # we record gradients rather than losses
+           for loss_fn in loss_fns:
+               optimizer.zero_grad()
+               loss_i = loss_fn(input_i)
+               loss_i.backward()
+               grads.append(get_gradient_vector(network))  # get the loss-specific gradient
+           g_config = ConFIG_update(grads)  # calculate the conflict-free direction
+           apply_gradient_vector(network, g_config)  # set the conflict-free direction to the network
+           optimizer.step()
+       ```
    """
    if not isinstance(grads, torch.Tensor):
        grads = torch.stack(grads)
@@ -80,7 +123,7 @@ def ConFIG_update(

class GradientOperator:
    """
-   A class that represents a gradient operator.
+   A base class that represents a gradient operator.

    Methods:
        calculate_gradient: Calculates the gradient based on the given gradients and losses.
@@ -91,13 +134,16 @@ class GradientOperator:
    def __init__(self):
        pass

-   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
        """
        Calculates the gradient based on the given gradients and losses.

        Args:
-           grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients.
-           losses (Optional[Sequence]): The losses (default: None).
+           grads (Union[torch.Tensor, Sequence[torch.Tensor]]): The gradients to update.
+               It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+           losses (Optional[Sequence], optional): The losses associated with the gradients.
+               The losses are passed to the weight and length models. If your weight/length model does not require
+               loss information, leave this as None. Defaults to None.

        Returns:
            torch.Tensor: The calculated gradient.
@@ -108,14 +154,17 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],
        """
        raise NotImplementedError("calculate_gradient method must be implemented")

-   def update_gradient(self, network: torch.nn.Module, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+   def update_gradient(self, network: torch.nn.Module, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> None:
        """
-       Updates the gradient of the network based on the calculated gradient.
+       Calculates the gradient and applies it to the network.

        Args:
-           network: The network.
-           grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients.
-           losses (Optional[Sequence]): The losses (default: None).
+           network (torch.nn.Module): The target network.
+           grads (Union[torch.Tensor, Sequence[torch.Tensor]]): The gradients to update.
+               It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+           losses (Optional[Sequence], optional): The losses associated with the gradients.
+               The losses are passed to the weight and length models. If your weight/length model does not require
+               loss information, leave this as None. Defaults to None.

        Returns:
            None
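The base-class contract above is easiest to see with a concrete subclass. Below is a minimal sketch, not part of this diff: `MeanOperator` is a hypothetical operator that simply averages the per-loss gradients, assuming the `GradientOperator` interface shown here.

```python
# Hypothetical example, not part of this diff: a custom operator that
# averages the per-loss gradient vectors (no conflict resolution).
from typing import Optional, Sequence, Union

import torch

from ConFIG.grad_operator import GradientOperator

class MeanOperator(GradientOperator):
    def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]],
                           losses: Optional[Sequence] = None) -> torch.Tensor:
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)  # stack the gradient vectors at dim 0
        return grads.mean(dim=0)        # plain average over the loss-specific gradients
```

If the base class implements `update_gradient` on top of `calculate_gradient` (as its docstring above suggests), a subclass like this inherits that behavior unchanged.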
@@ -126,13 +175,36 @@ def update_gradient(self, network:torch.nn.Module, grads: Union[torch.Tensor,Seq

class ConFIGOperator(GradientOperator):
    """
-   ConFIGOperator class represents a gradient operator for ConFIG algorithm.
+   Operator for the ConFIG algorithm.

    Args:
-       weight_model (WeightModel, optional): The weight model to be used for calculating the gradient. Defaults to EqualWeight().
-       length_model (LengthModel, optional): The length model to be used for calculating the gradient. Defaults to ProjectionLength().
-       allow_simplified_model (bool, optional): Whether to allow simplified model for calculating the gradient. Defaults to True.
-       use_latest_square (bool, optional): Whether to use the latest square for calculating the gradient. Defaults to True.
+       weight_model (WeightModel, optional): The weight model for calculating the direction weights.
+           Defaults to EqualWeight(), which keeps the final update gradient unbiased towards any single gradient.
+       length_model (LengthModel, optional): The length model for rescaling the length of the final gradient.
+           Defaults to ProjectionLength(), which projects each gradient onto the final gradient to determine the final length.
+       allow_simplified_model (bool, optional): Whether to allow the simplified model for calculating the gradient.
+           If True, the simplified form of the ConFIG method (ConFIG_update_double) is used when there are only two losses. Defaults to True.
+       use_latest_square (bool, optional): Whether to use the latest square method for calculating the best direction.
+           If False, the pseudo-inverse of the gradient matrix is calculated directly. Setting this to True is recommended. Defaults to True.
+
+   Examples:
+       ```python
+       import torch
+       from ConFIG.grad_operator import ConFIGOperator
+       from ConFIG.utils import get_gradient_vector, apply_gradient_vector
+       optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
+       operator = ConFIGOperator()  # initialize the operator
+       for input_i in dataset:
+           grads = []
+           for loss_fn in loss_fns:
+               optimizer.zero_grad()
+               loss_i = loss_fn(input_i)
+               loss_i.backward()
+               grads.append(get_gradient_vector(network))
+           g_config = operator.calculate_gradient(grads)  # calculate the conflict-free direction
+           apply_gradient_vector(network, g_config)  # or use `operator.update_gradient(network, grads)` to calculate and set the conflict-free direction in one step
+           optimizer.step()
+       ```
+
    """

    def __init__(self,
@@ -146,13 +218,16 @@ def __init__(self,
        self.allow_simplified_model = allow_simplified_model
        self.use_latest_square = use_latest_square

-   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
        """
        Calculates the gradient using the ConFIG algorithm.

        Args:
-           grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
-           losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
+           grads (Union[torch.Tensor, Sequence[torch.Tensor]]): The gradients to update.
+               It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+           losses (Optional[Sequence], optional): The losses associated with the gradients.
+               The losses are passed to the weight and length models. If your weight/length model does not require
+               loss information, leave this as None. Defaults to None.

        Returns:
            torch.Tensor: The calculated gradient.
@@ -186,16 +261,17 @@ class PCGradOperator(GradientOperator):
    """

-   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
        """
-       Calculates the gradient using the ConFIG algorithm.
+       Calculates the gradient using the PCGrad algorithm.

        Args:
-           grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
-           losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
+           grads (Union[torch.Tensor, Sequence[torch.Tensor]]): The gradients to update.
+               It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+           losses (Optional[Sequence], optional): This parameter is not used by this operator. Defaults to None.

        Returns:
-           torch.Tensor: The calculated gradient.
+           torch.Tensor: The calculated gradient using the PCGrad method.
        """
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)
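For completeness, here is a usage sketch for `PCGradOperator`, mirroring the `ConFIGOperator` example above. It is illustrative only and assumes `network`, `dataset`, and `loss_fns` are defined elsewhere, and that the base class's `update_gradient` applies the computed vector to the network as its docstring describes.

```python
# Illustrative usage of PCGradOperator, assuming the interface shown in this diff.
import torch

from ConFIG.grad_operator import PCGradOperator
from ConFIG.utils import get_gradient_vector

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)
operator = PCGradOperator()
for input_i in dataset:
    grads = []
    for loss_fn in loss_fns:
        optimizer.zero_grad()
        loss_fn(input_i).backward()
        grads.append(get_gradient_vector(network))  # loss-specific gradient
    # compute the PCGrad update and write it into the network's gradients
    operator.update_gradient(network, grads)
    optimizer.step()
```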
@@ -212,7 +288,7 @@ def calculate_gradient(self, grads: Union[torch.Tensor,Sequence[torch.Tensor]],

class IMTLGOperator(GradientOperator):
    """
-   PCGradOperator class represents a gradient operator for IMTLG algorithm.
+   IMTLGOperator class represents a gradient operator for the IMTL-G algorithm.

    @inproceedings{
    liu2021towards,
@@ -226,16 +302,17 @@ class IMTLGOperator(GradientOperator):
    """

-   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None):
+   def calculate_gradient(self, grads: Union[torch.Tensor, Sequence[torch.Tensor]], losses: Optional[Sequence] = None) -> torch.Tensor:
        """
-       Calculates the gradient using the ConFIG algorithm.
+       Calculates the gradient using the IMTL-G algorithm.

        Args:
-           grads (Union[torch.Tensor,Sequence[torch.Tensor]]): The gradients to be used for calculating the gradient.
-           losses (Optional[Sequence], optional): The losses to be used for calculating the gradient. Defaults to None.
+           grads (Union[torch.Tensor, Sequence[torch.Tensor]]): The gradients to update.
+               It can be a stack of gradient vectors (at dim 0) or a sequence of gradient vectors.
+           losses (Optional[Sequence], optional): This parameter is not used by this operator. Defaults to None.

        Returns:
-           torch.Tensor: The calculated gradient.
+           torch.Tensor: The calculated gradient using the IMTL-G method.
        """
        if not isinstance(grads, torch.Tensor):
            grads = torch.stack(grads)
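The docstrings above repeatedly note that `grads` may be either a sequence of gradient vectors or a tensor stacked at dim 0. A quick sketch of the stacked form, calling `calculate_gradient` directly on synthetic vectors; names and shapes here are illustrative only.

```python
# Illustrative direct call, assuming the IMTLGOperator interface shown above.
import torch

from ConFIG.grad_operator import IMTLGOperator

operator = IMTLGOperator()
g1 = torch.randn(1000)         # flattened gradient of loss 1
g2 = torch.randn(1000)         # flattened gradient of loss 2
grads = torch.stack([g1, g2])  # stack at dim 0, as the docstrings describe
g_update = operator.calculate_gradient(grads)  # combined update vector
assert g_update.shape == g1.shape
```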