
Commit

Merge pull request #34 from kozistr/feature/dummy
[Refactor] Refactor the codes
kozistr authored Oct 6, 2021
2 parents 3934366 + 7599d81 commit 49d3937
Showing 3 changed files with 43 additions and 44 deletions.
2 changes: 1 addition & 1 deletion pytorch_optimizer/__init__.py
@@ -17,4 +17,4 @@
from pytorch_optimizer.sam import SAM
from pytorch_optimizer.sgdp import SGDP

__VERSION__ = '0.0.11'
__VERSION__ = '0.1.0'
81 changes: 40 additions & 41 deletions pytorch_optimizer/pcgrad.py
@@ -61,15 +61,41 @@ def set_grad(self, grads):
p.grad = grads[idx]
idx += 1

def pc_backward(self, objectives: Iterable[nn.Module]):
"""Calculate the gradient of the parameters
:param objectives: Iterable[nn.Module]. a list of objectives
def retrieve_grad(self):
"""get the gradient of the parameters of the network with specific objective"""
grad, shape, has_grad = [], [], []
for group in self.optimizer.param_groups:
for p in group['params']:
if p.grad is None:
shape.append(p.shape)
grad.append(torch.zeros_like(p).to(p.device))
has_grad.append(torch.zeros_like(p).to(p.device))
continue

shape.append(p.grad.shape)
grad.append(p.grad.clone())
has_grad.append(torch.ones_like(p).to(p.device))

return grad, shape, has_grad

def pack_grad(self, objectives: Iterable[nn.Module]):
"""pack the gradient of the parameters of the network for each objective
:param objectives: Iterable[float]. a list of objectives
:return:
"""
grads, shapes, has_grads = self.pack_grad(objectives)
pc_grad = self.project_conflicting(grads, has_grads)
pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
self.set_grad(pc_grad)
grads, shapes, has_grads = [], [], []
for objective in objectives:
self.zero_grad()

objective.backward(retain_graph=True)

grad, shape, has_grad = self.retrieve_grad()

grads.append(self.flatten_grad(grad))
has_grads.append(self.flatten_grad(has_grad))
shapes.append(shape)

return grads, shapes, has_grads

def project_conflicting(self, grads, has_grads) -> torch.Tensor:
"""
@@ -99,40 +125,13 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:

return merged_grad

def retrieve_grad(self):
"""Get the gradient of the parameters of the network with specific objective
:return:
"""
grad, shape, has_grad = [], [], []
for group in self.optimizer.param_groups:
for p in group['params']:
if p.grad is None:
shape.append(p.shape)
grad.append(torch.zeros_like(p).to(p.device))
has_grad.append(torch.zeros_like(p).to(p.device))
continue

shape.append(p.grad.shape)
grad.append(p.grad.clone())
has_grad.append(torch.ones_like(p).to(p.device))

return grad, shape, has_grad

def pack_grad(self, objectives: Iterable[nn.Module]):
"""Pack the gradient of the parameters of the network for each objective
:param objectives: Iterable[float]. a list of objectives
def pc_backward(self, objectives: Iterable[nn.Module]):
"""calculate the gradient of the parameters
:param objectives: Iterable[nn.Module]. a list of objectives
:return:
"""
grads, shapes, has_grads = [], [], []
for objective in objectives:
self.zero_grad()

objective.backward(retain_graph=True)

grad, shape, has_grad = self.retrieve_grad()

grads.append(self.flatten_grad(grad))
has_grads.append(self.flatten_grad(has_grad))
shapes.append(shape)
grads, shapes, has_grads = self.pack_grad(objectives)
pc_grad = self.project_conflicting(grads, has_grads)
pc_grad = self.un_flatten_grad(pc_grad, shapes[0])

return grads, shapes, has_grads
self.set_grad(pc_grad)
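
The pcgrad.py change only reorders the methods (retrieve_grad and pack_grad now appear before project_conflicting, with pc_backward moved to the end); the behavior is unchanged. For orientation, here is a minimal usage sketch that is not part of the diff: it assumes PCGrad is constructed by wrapping a base optimizer and that it is exported from the package root like SAM and SGDP above. Note that although the type hint says Iterable[nn.Module], the objectives passed in practice are per-task loss tensors, since backward() is called on each one.

```python
# Minimal usage sketch (assumptions: PCGrad wraps a base optimizer and is
# importable from the package root; the constructor is not shown in this diff).
import torch
from torch import nn, optim

from pytorch_optimizer import PCGrad

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
base_optimizer = optim.AdamW(model.parameters(), lr=1e-3)
optimizer = PCGrad(base_optimizer)  # assumed constructor: PCGrad(optimizer)

x = torch.randn(4, 8)
y1, y2 = torch.randn(4), torch.randn(4)

out = model(x)
loss1 = nn.functional.mse_loss(out[:, 0], y1)  # objective for task 1
loss2 = nn.functional.mse_loss(out[:, 1], y2)  # objective for task 2

# pack_grad runs backward() once per objective, project_conflicting removes the
# conflicting components, and set_grad writes the merged gradient into p.grad.
optimizer.pc_backward([loss1, loss2])

# The projected gradients now sit on the shared parameters, so stepping the
# wrapped base optimizer applies them.
base_optimizer.step()
```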
4 changes: 2 additions & 2 deletions pytorch_optimizer/utils.py
@@ -26,7 +26,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
return x


def unit_norm(x: torch.Tensor) -> torch.Tensor:
def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
keep_dim: bool = True
dim: Optional[Union[int, Tuple[int, ...]]] = None

@@ -40,4 +40,4 @@ def unit_norm(x: torch.Tensor) -> torch.Tensor:
else:
dim = tuple(range(1, x_len))

return x.norm(dim=dim, keepdim=keep_dim, p=2.0)
return x.norm(dim=dim, keepdim=keep_dim, p=norm)
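
The utils.py change generalizes unit_norm from a hard-coded L2 norm to an arbitrary p-norm via the new norm argument, with a default of 2.0 so existing callers are unaffected. A small sketch of the difference, assuming the module path pytorch_optimizer.utils seen in this diff:

```python
import torch

# Assumed import path, taken from the file name in this diff.
from pytorch_optimizer.utils import unit_norm

w = torch.randn(4, 3, 3)      # e.g. a conv-style weight tensor

l2 = unit_norm(w)             # previous behavior: p=2.0 is still the default
l1 = unit_norm(w, norm=1.0)   # new: any p-norm over the same reduction dims

# Both keep the reduced dimensions (keepdim=True in the diff), so the results
# broadcast back against w, e.g. for per-output-channel normalization.
print(l2.shape, l1.shape)
```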
