You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
g_config=ConFIG_update(grads) # calculate the conflict-free direction
122
-
apply_gradient_vector(network) # set the condlict-free direction to the network
122
+
apply_gradient_vector(network,g_config) # set the condlict-free direction to the network
123
123
optimizer.step()
124
124
```
125
125
"""
@@ -231,7 +231,7 @@ class ConFIGOperator(GradientOperator):
231
231
loss_i.backward()
232
232
grads.append(get_gradient_vector(network))
233
233
g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
234
-
apply_gradient_vector(network) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
234
+
apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
Copy file name to clipboardExpand all lines: conflictfree/momentum_operator.py
+1-1Lines changed: 1 addition & 1 deletion
Original file line number
Diff line number
Diff line change
@@ -171,7 +171,7 @@ class PseudoMomentumOperator(MomentumOperator):
171
171
loss_i.backward()
172
172
grads.append(get_gradient_vector(network))
173
173
g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
174
-
apply_gradient_vector(network) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
174
+
apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
g_config=ConFIG_update(grads) # calculate the conflict-free direction
39
-
apply_gradient_vector(network) # set the condlict-free direction to the network
39
+
apply_gradient_vector(network,g_config) # set the condlict-free direction to the network
40
40
optimizer.step()
41
41
```
42
42
@@ -55,7 +55,7 @@ for input_i in dataset:
55
55
loss_i.backward()
56
56
grads.append(get_gradient_vector(network))
57
57
g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
58
-
apply_gradient_vector(network) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
58
+
apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
59
59
optimizer.step()
60
60
```
61
61
@@ -78,7 +78,7 @@ for input_i in dataset:
78
78
loss_i.backward()
79
79
grads.append(get_gradient_vector(network))
80
80
g_config=operator.calculate_gradient(grads) # calculate the conflict-free direction
81
-
apply_gradient_vector(network) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
81
+
apply_gradient_vector(network,g_config) # or simply use `operator.update_gradient(network,grads)` to calculate and set the condlict-free direction to the network
0 commit comments