diff --git a/pytorch_optimizer/__init__.py b/pytorch_optimizer/__init__.py
index a48ba62ab..216f97b7a 100644
--- a/pytorch_optimizer/__init__.py
+++ b/pytorch_optimizer/__init__.py
@@ -17,4 +17,4 @@
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
 
-__VERSION__ = '0.0.11'
+__VERSION__ = '0.1.0'
diff --git a/pytorch_optimizer/pcgrad.py b/pytorch_optimizer/pcgrad.py
index a0ee60fe0..362a1b4e9 100644
--- a/pytorch_optimizer/pcgrad.py
+++ b/pytorch_optimizer/pcgrad.py
@@ -61,15 +61,41 @@ def set_grad(self, grads):
                 p.grad = grads[idx]
                 idx += 1
 
-    def pc_backward(self, objectives: Iterable[nn.Module]):
-        """Calculate the gradient of the parameters
-        :param objectives: Iterable[nn.Module]. a list of objectives
+    def retrieve_grad(self):
+        """get the gradient of the parameters of the network with specific objective"""
+        grad, shape, has_grad = [], [], []
+        for group in self.optimizer.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    shape.append(p.shape)
+                    grad.append(torch.zeros_like(p).to(p.device))
+                    has_grad.append(torch.zeros_like(p).to(p.device))
+                    continue
+
+                shape.append(p.grad.shape)
+                grad.append(p.grad.clone())
+                has_grad.append(torch.ones_like(p).to(p.device))
+
+        return grad, shape, has_grad
+
+    def pack_grad(self, objectives: Iterable[nn.Module]):
+        """pack the gradient of the parameters of the network for each objective
+        :param objectives: Iterable[float]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = self.pack_grad(objectives)
-        pc_grad = self.project_conflicting(grads, has_grads)
-        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
-        self.set_grad(pc_grad)
+        grads, shapes, has_grads = [], [], []
+        for objective in objectives:
+            self.zero_grad()
+
+            objective.backward(retain_graph=True)
+
+            grad, shape, has_grad = self.retrieve_grad()
+
+            grads.append(self.flatten_grad(grad))
+            has_grads.append(self.flatten_grad(has_grad))
+            shapes.append(shape)
+
+        return grads, shapes, has_grads
 
     def project_conflicting(self, grads, has_grads) -> torch.Tensor:
         """
@@ -99,40 +125,13 @@ def project_conflicting(self, grads, has_grads) -> torch.Tensor:
 
         return merged_grad
 
-    def retrieve_grad(self):
-        """Get the gradient of the parameters of the network with specific objective
-        :return:
-        """
-        grad, shape, has_grad = [], [], []
-        for group in self.optimizer.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    shape.append(p.shape)
-                    grad.append(torch.zeros_like(p).to(p.device))
-                    has_grad.append(torch.zeros_like(p).to(p.device))
-                    continue
-
-                shape.append(p.grad.shape)
-                grad.append(p.grad.clone())
-                has_grad.append(torch.ones_like(p).to(p.device))
-
-        return grad, shape, has_grad
-
-    def pack_grad(self, objectives: Iterable[nn.Module]):
-        """Pack the gradient of the parameters of the network for each objective
-        :param objectives: Iterable[float]. a list of objectives
+    def pc_backward(self, objectives: Iterable[nn.Module]):
+        """calculate the gradient of the parameters
+        :param objectives: Iterable[nn.Module]. a list of objectives
         :return:
         """
-        grads, shapes, has_grads = [], [], []
-        for objective in objectives:
-            self.zero_grad()
-
-            objective.backward(retain_graph=True)
-
-            grad, shape, has_grad = self.retrieve_grad()
-
-            grads.append(self.flatten_grad(grad))
-            has_grads.append(self.flatten_grad(has_grad))
-            shapes.append(shape)
+        grads, shapes, has_grads = self.pack_grad(objectives)
+        pc_grad = self.project_conflicting(grads, has_grads)
+        pc_grad = self.un_flatten_grad(pc_grad, shapes[0])
 
-        return grads, shapes, has_grads
+        self.set_grad(pc_grad)
diff --git a/pytorch_optimizer/utils.py b/pytorch_optimizer/utils.py
index a56b53d98..2ea8dc503 100644
--- a/pytorch_optimizer/utils.py
+++ b/pytorch_optimizer/utils.py
@@ -26,7 +26,7 @@ def normalize_gradient(x: torch.Tensor, use_channels: bool = False, epsilon: flo
     return x
 
 
-def unit_norm(x: torch.Tensor) -> torch.Tensor:
+def unit_norm(x: torch.Tensor, norm: float = 2.0) -> torch.Tensor:
     keep_dim: bool = True
     dim: Optional[Union[int, Tuple[int, ...]]] = None
@@ -40,4 +40,4 @@ def unit_norm(x: torch.Tensor) -> torch.Tensor:
     else:
         dim = tuple(range(1, x_len))
 
-    return x.norm(dim=dim, keepdim=keep_dim, p=2.0)
+    return x.norm(dim=dim, keepdim=keep_dim, p=norm)
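For reference, a minimal usage sketch of the PCGrad flow this diff reorders (pack_grad -> project_conflicting -> un_flatten_grad -> set_grad) and of the new norm argument on unit_norm. The two-task model, the loss functions, and the PCGrad(optimizer) constructor call are illustrative assumptions, not part of this diff.

# Sketch only: assumes the class in pytorch_optimizer/pcgrad.py is named PCGrad,
# wraps a torch optimizer as PCGrad(optimizer), and that pc_backward() accepts a
# list of scalar loss tensors.
import torch
from torch import nn

from pytorch_optimizer.pcgrad import PCGrad
from pytorch_optimizer.utils import unit_norm

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
pc_grad = PCGrad(optimizer)  # assumed constructor signature

x = torch.randn(8, 4)
y1, y2 = torch.randn(8, 2), torch.randn(8, 2)

out = model(x)
losses = [nn.functional.mse_loss(out, y1), nn.functional.l1_loss(out, y2)]

# pc_backward() now delegates to pack_grad() (one backward pass per objective),
# project_conflicting() (gradient surgery), and set_grad() (write-back).
pc_grad.pc_backward(losses)
optimizer.step()

# unit_norm() now takes the norm order instead of a hard-coded p=2.0.
g = torch.randn(16, 3, 3)
print(unit_norm(g, norm=2.0).shape)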