From a4b9348f5789e44e34453f78a227853941f3fa65 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Thu, 31 Mar 2022 20:16:39 +0800 Subject: [PATCH 01/11] support ada module and training --- mmgen/core/evaluation/metrics.py | 17 + .../architectures/stylegan/ada/augment.py | 771 ++++++++++++++++++ .../stylegan/ada/grid_sample_gradfix.py | 108 +++ .../models/architectures/stylegan/ada/misc.py | 31 + .../architectures/stylegan/ada/upfirdn2d.py | 193 +++++ .../stylegan/generator_discriminator_v2.py | 100 ++- .../stylegan/modules/styleganv3_modules.py | 33 +- mmgen/models/gans/static_unconditional_gan.py | 11 + 8 files changed, 1256 insertions(+), 8 deletions(-) create mode 100644 mmgen/models/architectures/stylegan/ada/augment.py create mode 100644 mmgen/models/architectures/stylegan/ada/grid_sample_gradfix.py create mode 100644 mmgen/models/architectures/stylegan/ada/misc.py create mode 100644 mmgen/models/architectures/stylegan/ada/upfirdn2d.py diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index ba7802bda..9637fdf3e 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -953,6 +953,13 @@ def summary(self): return self._result_dict def extract_features(self, images): + """Extracting image features. + + Args: + images (torch.Tensor): Images tensor. + Returns: + torch.Tensor: Vgg16 features of input images. + """ if self.use_tero_scirpt: feature = self.vgg16(images, return_features=True) else: @@ -1278,6 +1285,16 @@ def summary(self): return ppl_score def get_sampler(self, model, batch_size, sample_model): + """Get sampler for sampling along the path. + + Args: + model (nn.Module): Generative model. + batch_size (int): Sampling batch size. + sample_model (str): Which model you want to use. ['ema', + 'orig']. Defaults to 'ema'. + Returns: + Object: A sampler for calculating path length regularization. + """ if sample_model == 'ema': generator = model.generator_ema else: diff --git a/mmgen/models/architectures/stylegan/ada/augment.py b/mmgen/models/architectures/stylegan/ada/augment.py new file mode 100644 index 000000000..28f20e0b2 --- /dev/null +++ b/mmgen/models/architectures/stylegan/ada/augment.py @@ -0,0 +1,771 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. + +import numpy as np +import scipy.signal +import torch + +from mmgen.ops import conv2d_gradfix +from . import grid_sample_gradfix, misc, upfirdn2d + +# ---------------------------------------------------------------------------- +# Coefficients of various wavelet decomposition low-pass filters. 
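+# ('dbN' entries are Daubechies scaling (low-pass) coefficients and 'symN'
+# entries are the corresponding symlets; 'sym6' is used below as the
+# orthogonal low-pass filter for geometric warps and 'sym2' seeds the
+# image-space filter bank.)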
+ +wavelets = { + 'haar': [0.7071067811865476, 0.7071067811865476], + 'db1': [0.7071067811865476, 0.7071067811865476], + 'db2': [ + -0.12940952255092145, 0.22414386804185735, 0.836516303737469, + 0.48296291314469025 + ], + 'db3': [ + 0.035226291882100656, -0.08544127388224149, -0.13501102001039084, + 0.4598775021193313, 0.8068915093133388, 0.3326705529509569 + ], + 'db4': [ + -0.010597401784997278, 0.032883011666982945, 0.030841381835986965, + -0.18703481171888114, -0.02798376941698385, 0.6308807679295904, + 0.7148465705525415, 0.23037781330885523 + ], + 'db5': [ + 0.003335725285001549, -0.012580751999015526, -0.006241490213011705, + 0.07757149384006515, -0.03224486958502952, -0.24229488706619015, + 0.13842814590110342, 0.7243085284385744, 0.6038292697974729, + 0.160102397974125 + ], + 'db6': [ + -0.00107730108499558, 0.004777257511010651, 0.0005538422009938016, + -0.031582039318031156, 0.02752286553001629, 0.09750160558707936, + -0.12976686756709563, -0.22626469396516913, 0.3152503517092432, + 0.7511339080215775, 0.4946238903983854, 0.11154074335008017 + ], + 'db7': [ + 0.0003537138000010399, -0.0018016407039998328, 0.00042957797300470274, + 0.012550998556013784, -0.01657454163101562, -0.03802993693503463, + 0.0806126091510659, 0.07130921926705004, -0.22403618499416572, + -0.14390600392910627, 0.4697822874053586, 0.7291320908465551, + 0.39653931948230575, 0.07785205408506236 + ], + 'db8': [ + -0.00011747678400228192, 0.0006754494059985568, -0.0003917403729959771, + -0.00487035299301066, 0.008746094047015655, 0.013981027917015516, + -0.04408825393106472, -0.01736930100202211, 0.128747426620186, + 0.00047248457399797254, -0.2840155429624281, -0.015829105256023893, + 0.5853546836548691, 0.6756307362980128, 0.3128715909144659, + 0.05441584224308161 + ], + 'sym2': [ + -0.12940952255092145, 0.22414386804185735, 0.836516303737469, + 0.48296291314469025 + ], + 'sym3': [ + 0.035226291882100656, -0.08544127388224149, -0.13501102001039084, + 0.4598775021193313, 0.8068915093133388, 0.3326705529509569 + ], + 'sym4': [ + -0.07576571478927333, -0.02963552764599851, 0.49761866763201545, + 0.8037387518059161, 0.29785779560527736, -0.09921954357684722, + -0.012603967262037833, 0.0322231006040427 + ], + 'sym5': [ + 0.027333068345077982, 0.029519490925774643, -0.039134249302383094, + 0.1993975339773936, 0.7234076904024206, 0.6339789634582119, + 0.01660210576452232, -0.17532808990845047, -0.021101834024758855, + 0.019538882735286728 + ], + 'sym6': [ + 0.015404109327027373, 0.0034907120842174702, -0.11799011114819057, + -0.048311742585633, 0.4910559419267466, 0.787641141030194, + 0.3379294217276218, -0.07263752278646252, -0.021060292512300564, + 0.04472490177066578, 0.0017677118642428036, -0.007800708325034148 + ], + 'sym7': [ + 0.002681814568257878, -0.0010473848886829163, -0.01263630340325193, + 0.03051551316596357, 0.0678926935013727, -0.049552834937127255, + 0.017441255086855827, 0.5361019170917628, 0.767764317003164, + 0.2886296317515146, -0.14004724044296152, -0.10780823770381774, + 0.004010244871533663, 0.010268176708511255 + ], + 'sym8': [ + -0.0033824159510061256, -0.0005421323317911481, 0.03169508781149298, + 0.007607487324917605, -0.1432942383508097, -0.061273359067658524, + 0.4813596512583722, 0.7771857517005235, 0.3644418948353314, + -0.05194583810770904, -0.027219029917056003, 0.049137179673607506, + 0.003808752013890615, -0.01495225833704823, -0.0003029205147213668, + 0.0018899503327594609 + ], +} + +# ---------------------------------------------------------------------------- +# Helpers for 
constructing transformation matrices. + + +def matrix(*rows, device=None): + """Constructing transformation matrices. + Args: + device (str|torch.device, optional): Matrix device. Defaults to None. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + assert all(len(row) == len(rows[0]) for row in rows) + elems = [x for row in rows for x in row] + ref = [x for x in elems if isinstance(x, torch.Tensor)] + if len(ref) == 0: + return misc.constant(np.asarray(rows), device=device) + assert device is None or device == ref[0].device + # change `x.float()` to support pt1.5 + elems = [ + x.float() if isinstance(x, torch.Tensor) else misc.constant( + x, shape=ref[0].shape, device=ref[0].device) for x in elems + ] + return torch.stack(elems, dim=-1).reshape(ref[0].shape + (len(rows), -1)) + + +def translate2d(tx, ty, **kwargs): + """Construct 2d translation matrix. + Args: + tx (float): X-direction translation amount. + ty (float): Y-direction translation amount. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return matrix([1, 0, tx], [0, 1, ty], [0, 0, 1], **kwargs) + + +def translate3d(tx, ty, tz, **kwargs): + """Construct 3d translation matrix. + Args: + tx (float): X-direction translation amount. + ty (float): Y-direction translation amount. + tz (float): Z-direction translation amount. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return matrix([1, 0, 0, tx], [0, 1, 0, ty], [0, 0, 1, tz], [0, 0, 0, 1], + **kwargs) + + +def scale2d(sx, sy, **kwargs): + """Construct 2d scaling matrix. + Args: + sx (float): X-direction scaling coefficient. + sy (float): Y-direction scaling coefficient. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return matrix([sx, 0, 0], [0, sy, 0], [0, 0, 1], **kwargs) + + +def scale3d(sx, sy, sz, **kwargs): + """Construct 3d scaling matrix. + Args: + sx (float): X-direction scaling coefficient. + sy (float): Y-direction scaling coefficient. + sz (float): Z-direction scaling coefficient. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return matrix([sx, 0, 0, 0], [0, sy, 0, 0], [0, 0, sz, 0], [0, 0, 0, 1], + **kwargs) + + +def rotate2d(theta, **kwargs): + """Construct 2d rotating matrix. + Args: + theta (float): Counter-clock wise rotation angle. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return matrix([torch.cos(theta), torch.sin(-theta), 0], + [torch.sin(theta), torch.cos(theta), 0], [0, 0, 1], **kwargs) + + +def rotate3d(v, theta, **kwargs): + """Constructing 3d rotating matrix. + Args: + v (torch.Tensor): Luma axis. + theta (float): Rotate theta counter-clock wise with ``v`` as the axis. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + vx = v[..., 0] + vy = v[..., 1] + vz = v[..., 2] + s = torch.sin(theta) + c = torch.cos(theta) + cc = 1 - c + return matrix( + [vx * vx * cc + c, vx * vy * cc - vz * s, vx * vz * cc + vy * s, 0], + [vy * vx * cc + vz * s, vy * vy * cc + c, vy * vz * cc - vx * s, 0], + [vz * vx * cc - vy * s, vz * vy * cc + vx * s, vz * vz * cc + c, 0], + [0, 0, 0, 1], **kwargs) + + +def translate2d_inv(tx, ty, **kwargs): + """Construct inverse matrix of 2d translation matrix. + Args: + tx (float): X-direction translation amount. + ty (float): Y-direction translation amount. 
+ Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return translate2d(-tx, -ty, **kwargs) + + +def scale2d_inv(sx, sy, **kwargs): + """Construct inverse matrix of 2d scaling matrix. + Args: + sx (float): X-direction scaling coefficient. + sy (float): Y-direction scaling coefficient. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return scale2d(1 / sx, 1 / sy, **kwargs) + + +def rotate2d_inv(theta, **kwargs): + """Construct inverse matrix of 2d rotating matrix. + Args: + theta (float): Counter-clock wise rotation angle. + Returns: + ndarry | Tensor : Transformation matrices in np.ndarry or torch.Tensor + format. + """ + return rotate2d(-theta, **kwargs) + + +# ---------------------------------------------------------------------------- +# Versatile image augmentation pipeline from the paper +# "Training Generative Adversarial Networks with Limited Data". +# +# All augmentations are disabled by default; individual augmentations can +# be enabled by setting their probability multipliers to 1. + + +class AugmentPipe(torch.nn.Module): + """Augmentation pipeline include multiple geometric and color + transformations. + Note: The meaning of arguments are written in the comments of + ``__init__`` function. + """ + def __init__( + self, + xflip=0, + rotate90=0, + xint=0, + xint_max=0.125, + scale=0, + rotate=0, + aniso=0, + xfrac=0, + scale_std=0.2, + rotate_max=1, + aniso_std=0.2, + xfrac_std=0.125, + brightness=0, + contrast=0, + lumaflip=0, + hue=0, + saturation=0, + brightness_std=0.2, + contrast_std=0.5, + hue_max=1, + saturation_std=1, + imgfilter=0, + imgfilter_bands=[1, 1, 1, 1], + imgfilter_std=1, + noise=0, + cutout=0, + noise_std=0.1, + cutout_size=0.5, + ): + super().__init__() + self.register_buffer('p', torch.ones( + [])) # Overall multiplier for augmentation probability. + + # Pixel blitting. + self.xflip = float(xflip) # Probability multiplier for x-flip. + self.rotate90 = float( + rotate90) # Probability multiplier for 90 degree rotations. + self.xint = float( + xint) # Probability multiplier for integer translation. + self.xint_max = float( + xint_max + ) # Range of integer translation, relative to image dimensions. + + # General geometric transformations. + self.scale = float( + scale) # Probability multiplier for isotropic scaling. + self.rotate = float( + rotate) # Probability multiplier for arbitrary rotation. + self.aniso = float( + aniso) # Probability multiplier for anisotropic scaling. + self.xfrac = float( + xfrac) # Probability multiplier for fractional translation. + self.scale_std = float( + scale_std) # Log2 standard deviation of isotropic scaling. + self.rotate_max = float( + rotate_max) # Range of arbitrary rotation, 1 = full circle. + self.aniso_std = float( + aniso_std) # Log2 standard deviation of anisotropic scaling. + self.xfrac_std = float( + xfrac_std + ) # Standard deviation of frational translation, relative to img dims. + + # Color transformations. + self.brightness = float( + brightness) # Probability multiplier for brightness. + self.contrast = float(contrast) # Probability multiplier for contrast. + self.lumaflip = float( + lumaflip) # Probability multiplier for luma flip. + self.hue = float(hue) # Probability multiplier for hue rotation. + self.saturation = float( + saturation) # Probability multiplier for saturation. + self.brightness_std = float( + brightness_std) # Standard deviation of brightness. 
+ self.contrast_std = float( + contrast_std) # Log2 standard deviation of contrast. + self.hue_max = float( + hue_max) # Range of hue rotation, 1 = full circle. + self.saturation_std = float( + saturation_std) # Log2 standard deviation of saturation. + + # Image-space filtering. + self.imgfilter = float( + imgfilter) # Probability multiplier for image-space filtering. + self.imgfilter_bands = list( + imgfilter_bands + ) # Probability multipliers for individual frequency bands. + self.imgfilter_std = float( + imgfilter_std + ) # Log2 standard deviation of image-space filter amplification. + + # Image-space corruptions. + self.noise = float( + noise) # Probability multiplier for additive RGB noise. + self.cutout = float(cutout) # Probability multiplier for cutout. + self.noise_std = float( + noise_std) # Standard deviation of additive RGB noise. + self.cutout_size = float( + cutout_size + ) # Size of the cutout rectangle, relative to image dimensions. + + # Setup orthogonal lowpass filter for geometric augmentations. + self.register_buffer('Hz_geom', + upfirdn2d.setup_filter(wavelets['sym6'])) + + # Construct filter bank for image-space filtering. + Hz_lo = np.asarray(wavelets['sym2']) # H(z) + Hz_hi = Hz_lo * ((-1)**np.arange(Hz_lo.size)) # H(-z) + Hz_lo2 = np.convolve(Hz_lo, Hz_lo[::-1]) / 2 # H(z) * H(z^-1) / 2 + Hz_hi2 = np.convolve(Hz_hi, Hz_hi[::-1]) / 2 # H(-z) * H(-z^-1) / 2 + Hz_fbank = np.eye(4, 1) # Bandpass(H(z), b_i) + for i in range(1, Hz_fbank.shape[0]): + Hz_fbank = np.dstack([Hz_fbank, np.zeros_like(Hz_fbank) + ]).reshape(Hz_fbank.shape[0], -1)[:, :-1] + Hz_fbank = scipy.signal.convolve(Hz_fbank, [Hz_lo2]) + Hz_fbank[i, (Hz_fbank.shape[1] - Hz_hi2.size) // + 2:(Hz_fbank.shape[1] + Hz_hi2.size) // 2] += Hz_hi2 + self.register_buffer('Hz_fbank', + torch.as_tensor(Hz_fbank, dtype=torch.float32)) + + def forward(self, images, debug_percentile=None): + assert isinstance(images, torch.Tensor) and images.ndim == 4 + batch_size, num_channels, height, width = images.shape + device = images.device + if debug_percentile is not None: + debug_percentile = torch.as_tensor(debug_percentile, + dtype=torch.float32, + device=device) + + # ------------------------------------- + # Select parameters for pixel blitting. + # ------------------------------------- + + # Initialize inverse homogeneous 2D transform: + # G_inv @ pixel_out ==> pixel_in + I_3 = torch.eye(3, device=device) + G_inv = I_3 + + # Apply x-flip with probability (xflip * strength). + if self.xflip > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 2) + i = torch.where( + torch.rand([batch_size], device=device) < self.xflip * self.p, + i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1) + + # Apply 90 degree rotations with probability (rotate90 * strength). + if self.rotate90 > 0: + i = torch.floor(torch.rand([batch_size], device=device) * 4) + i = torch.where( + torch.rand([batch_size], device=device) < + self.rotate90 * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 4)) + G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i) + + # Apply integer translation with probability (xint * strength). 
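+        # (As with the flips/rotations above, the translation is folded into
+        #  G_inv, the transform from output pixel coordinates back to input
+        #  pixel coordinates, so all geometric ops can be applied later by a
+        #  single warp.)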
+ if self.xint > 0: + t = (torch.rand([batch_size, 2], device=device) * 2 - + 1) * self.xint_max + t = torch.where( + torch.rand([batch_size, 1], device=device) < + self.xint * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like(t, + (debug_percentile * 2 - 1) * self.xint_max) + G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width), + torch.round(t[:, 1] * height)) + + # -------------------------------------------------------- + # Select parameters for general geometric transformations. + # -------------------------------------------------------- + + # support for pt1.5 (pt1.5 does not contain exp2) + _scalor_log2 = torch.log( + torch.tensor(2., device=images.device, dtype=images.dtype)) + + # Apply isotropic scaling with probability (scale * strength). + if self.scale > 0: + s = torch.exp( + torch.randn([batch_size], device=device) * self.scale_std * + _scalor_log2) + s = torch.where( + torch.rand([batch_size], device=device) < self.scale * self.p, + s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like( + s, + torch.exp2( + torch.erfinv(debug_percentile * 2 - 1) * + self.scale_std)) + G_inv = G_inv @ scale2d_inv(s, s) + + # Apply pre-rotation with probability p_rot. + p_rot = 1 - torch.sqrt( + (1 - self.rotate * self.p).clamp(0, 1)) # P(pre OR post) = p + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - + 1) * np.pi * self.rotate_max + theta = torch.where( + torch.rand([batch_size], device=device) < p_rot, theta, + torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * + np.pi * self.rotate_max) + G_inv = G_inv @ rotate2d_inv(-theta) # Before anisotropic scaling. + + # Apply anisotropic scaling with probability (aniso * strength). + if self.aniso > 0: + s = torch.exp( + torch.randn([batch_size], device=device) * self.aniso_std * + _scalor_log2) + s = torch.where( + torch.rand([batch_size], device=device) < self.aniso * self.p, + s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like( + s, + torch.exp2( + torch.erfinv(debug_percentile * 2 - 1) * + self.aniso_std)) + G_inv = G_inv @ scale2d_inv(s, 1 / s) + + # Apply post-rotation with probability p_rot. + if self.rotate > 0: + theta = (torch.rand([batch_size], device=device) * 2 - + 1) * np.pi * self.rotate_max + theta = torch.where( + torch.rand([batch_size], device=device) < p_rot, theta, + torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.zeros_like(theta) + G_inv = G_inv @ rotate2d_inv(-theta) # After anisotropic scaling. + + # Apply fractional translation with probability (xfrac * strength). + if self.xfrac > 0: + t = torch.randn([batch_size, 2], device=device) * self.xfrac_std + t = torch.where( + torch.rand([batch_size, 1], device=device) < + self.xfrac * self.p, t, torch.zeros_like(t)) + if debug_percentile is not None: + t = torch.full_like( + t, + torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std) + G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height) + + # ---------------------------------- + # Execute geometric transformations. + # ---------------------------------- + + # Execute if the transform is not identity. + if G_inv is not I_3: + + # Calculate padding. 
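+            # (The four image corners are pushed through G_inv; the resulting
+            #  margins decide how much reflect padding is needed so the
+            #  subsequent warp never samples outside the padded canvas.)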
+ cx = (width - 1) / 2 + cy = (height - 1) / 2 + cp = matrix([-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], + device=device) # [idx, xyz] + cp = G_inv @ cp.t() # [batch, xyz, idx] + Hz_pad = self.Hz_geom.shape[0] // 4 + margin = cp[:, :2, :].permute(1, 0, + 2).flatten(1) # [xy, batch * idx] + margin = torch.cat([-margin, + margin]).max(dim=1).values # [x0, y0, x1, y1] + margin = margin + misc.constant( + [Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device) + margin = margin.max(misc.constant([0, 0] * 2, device=device)) + margin = margin.min( + misc.constant([width - 1, height - 1] * 2, device=device)) + mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) + + # Pad image and adjust origin. + images = torch.nn.functional.pad(input=images, + pad=[mx0, mx1, my0, my1], + mode='reflect') + G_inv = translate2d(torch.true_divide(mx0 - mx1, 2), + torch.true_divide(my0 - my1, 2)) @ G_inv + + # Upsample. + images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) + G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv( + 2, 2, device=device) + G_inv = translate2d(-0.5, -0.5, + device=device) @ G_inv @ translate2d_inv( + -0.5, -0.5, device=device) + + # Execute transformation. + shape = [ + batch_size, num_channels, (height + Hz_pad * 2) * 2, + (width + Hz_pad * 2) * 2 + ] + G_inv = scale2d(2 / images.shape[3], + 2 / images.shape[2], + device=device) @ G_inv @ scale2d_inv( + 2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :], + size=shape, + align_corners=False) + images = grid_sample_gradfix.grid_sample(images, grid) + + # Downsample and crop. + images = upfirdn2d.downsample2d(x=images, + f=self.Hz_geom, + down=2, + padding=-Hz_pad * 2, + flip_filter=True) + + # -------------------------------------------- + # Select parameters for color transformations. + # -------------------------------------------- + + # Initialize homogeneous 3D transformation matrix: + # C @ color_in ==> color_out + I_4 = torch.eye(4, device=device) + C = I_4 + + # Apply brightness with probability (brightness * strength). + if self.brightness > 0: + b = torch.randn([batch_size], device=device) * self.brightness_std + b = torch.where( + torch.rand([batch_size], device=device) < + self.brightness * self.p, b, torch.zeros_like(b)) + if debug_percentile is not None: + b = torch.full_like( + b, + torch.erfinv(debug_percentile * 2 - 1) * + self.brightness_std) + C = translate3d(b, b, b) @ C + + # Apply contrast with probability (contrast * strength). + if self.contrast > 0: + c = torch.exp( + torch.randn([batch_size], device=device) * self.contrast_std * + _scalor_log2) + c = torch.where( + torch.rand([batch_size], device=device) < + self.contrast * self.p, c, torch.ones_like(c)) + if debug_percentile is not None: + c = torch.full_like( + c, + torch.exp2( + torch.erfinv(debug_percentile * 2 - 1) * + self.contrast_std)) + C = scale3d(c, c, c) @ C + + # Apply luma flip with probability (lumaflip * strength). + v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), + device=device) # Luma axis. + if self.lumaflip > 0: + i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) + i = torch.where( + torch.rand([batch_size, 1, 1], device=device) < + self.lumaflip * self.p, i, torch.zeros_like(i)) + if debug_percentile is not None: + i = torch.full_like(i, torch.floor(debug_percentile * 2)) + C = (I_4 - 2 * v.ger(v) * i) @ C # Householder reflection. + + # Apply hue rotation with probability (hue * strength). 
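+        # (Hue rotation is realised as a 3D rotation of the RGB color vector
+        #  around the luma axis v, leaving the component along v unchanged
+        #  while the chroma components are rotated.)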
+ if self.hue > 0 and num_channels > 1: + theta = (torch.rand([batch_size], device=device) * 2 - + 1) * np.pi * self.hue_max + theta = torch.where( + torch.rand([batch_size], device=device) < self.hue * self.p, + theta, torch.zeros_like(theta)) + if debug_percentile is not None: + theta = torch.full_like(theta, (debug_percentile * 2 - 1) * + np.pi * self.hue_max) + C = rotate3d(v, theta) @ C # Rotate around v. + + # Apply saturation with probability (saturation * strength). + if self.saturation > 0 and num_channels > 1: + s = torch.exp( + torch.randn([batch_size, 1, 1], device=device) * + self.saturation_std * _scalor_log2) + s = torch.where( + torch.rand([batch_size, 1, 1], device=device) < + self.saturation * self.p, s, torch.ones_like(s)) + if debug_percentile is not None: + s = torch.full_like( + s, + torch.exp2( + torch.erfinv(debug_percentile * 2 - 1) * + self.saturation_std)) + C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C + + # ------------------------------ + # Execute color transformations. + # ------------------------------ + + # Execute if the transform is not identity. + if C is not I_4: + images = images.reshape([batch_size, num_channels, height * width]) + if num_channels == 3: + images = C[:, :3, :3] @ images + C[:, :3, 3:] + elif num_channels == 1: + C = C[:, :3, :].mean(dim=1, keepdims=True) + images = images * C[:, :, :3].sum(dim=2, + keepdims=True) + C[:, :, 3:] + else: + raise ValueError( + 'Image must be RGB (3 channels) or L (1 channel)') + images = images.reshape([batch_size, num_channels, height, width]) + + # ---------------------- + # Image-space filtering. + # ---------------------- + + if self.imgfilter > 0: + num_bands = self.Hz_fbank.shape[0] + assert len(self.imgfilter_bands) == num_bands + expected_power = misc.constant( + np.array([10, 1, 1, 1]) / 13, + device=device) # Expected power spectrum (1/f). + + # Apply amplification for each band with probability + # (imgfilter * strength * band_strength). + g = torch.ones([batch_size, num_bands], + device=device) # Global gain vector (identity). + for i, band_strength in enumerate(self.imgfilter_bands): + t_i = torch.exp( + torch.randn([batch_size], device=device) * + self.imgfilter_std * _scalor_log2) + t_i = torch.where( + torch.rand([batch_size], device=device) < + self.imgfilter * self.p * band_strength, t_i, + torch.ones_like(t_i)) + if debug_percentile is not None: + t_i = torch.full_like( + t_i, + torch.exp2( + torch.erfinv(debug_percentile * 2 - 1) * + self.imgfilter_std) + ) if band_strength > 0 else torch.ones_like(t_i) + t = torch.ones([batch_size, num_bands], + device=device) # Temporary gain vector. + t[:, i] = t_i # Replace i'th element. + t = t / (expected_power * t.square()).sum( + dim=-1, keepdims=True).sqrt() # Normalize power. + g = g * t # Accumulate into global gain. + + # Construct combined amplification filter. + Hz_prime = g @ self.Hz_fbank # [batch, tap] + Hz_prime = Hz_prime.unsqueeze(1).repeat( + [1, num_channels, 1]) # [batch, channels, tap] + Hz_prime = Hz_prime.reshape([batch_size * num_channels, 1, + -1]) # [batch * channels, 1, tap] + + # Apply filter. 
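+            # (The per-image 1D filter Hz_prime is applied separably: first
+            #  along the width with a [1, tap] kernel, then along the height
+            #  with a [tap, 1] kernel, using grouped conv2d so each
+            #  image/channel pair keeps its own gains.)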
+ p = self.Hz_fbank.shape[1] // 2 + images = images.reshape( + [1, batch_size * num_channels, height, width]) + images = torch.nn.functional.pad(input=images, + pad=[p, p, p, p], + mode='reflect') + images = conv2d_gradfix.conv2d(input=images, + weight=Hz_prime.unsqueeze(2), + groups=batch_size * num_channels) + images = conv2d_gradfix.conv2d(input=images, + weight=Hz_prime.unsqueeze(3), + groups=batch_size * num_channels) + images = images.reshape([batch_size, num_channels, height, width]) + + # ------------------------ + # Image-space corruptions. + # ------------------------ + + # Apply additive RGB noise with probability (noise * strength). + if self.noise > 0: + sigma = torch.randn([batch_size, 1, 1, 1], + device=device).abs() * self.noise_std + sigma = torch.where( + torch.rand([batch_size, 1, 1, 1], device=device) < + self.noise * self.p, sigma, torch.zeros_like(sigma)) + if debug_percentile is not None: + sigma = torch.full_like( + sigma, + torch.erfinv(debug_percentile) * self.noise_std) + images = images + torch.randn( + [batch_size, num_channels, height, width], + device=device) * sigma + + # Apply cutout with probability (cutout * strength). + if self.cutout > 0: + size = torch.full([batch_size, 2, 1, 1, 1], + self.cutout_size, + device=device) + size = torch.where( + torch.rand([batch_size, 1, 1, 1, 1], device=device) < + self.cutout * self.p, size, torch.zeros_like(size)) + center = torch.rand([batch_size, 2, 1, 1, 1], device=device) + if debug_percentile is not None: + size = torch.full_like(size, self.cutout_size) + center = torch.full_like(center, debug_percentile) + coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) + coord_y = torch.arange(height, + device=device).reshape([1, 1, -1, 1]) + mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= + size[:, 0] / 2) + mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= + size[:, 1] / 2) + mask = torch.logical_or(mask_x, mask_y).to(torch.float32) + images = images * mask + + return images diff --git a/mmgen/models/architectures/stylegan/ada/grid_sample_gradfix.py b/mmgen/models/architectures/stylegan/ada/grid_sample_gradfix.py new file mode 100644 index 000000000..fd6cbb4f2 --- /dev/null +++ b/mmgen/models/architectures/stylegan/ada/grid_sample_gradfix.py @@ -0,0 +1,108 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# NVIDIA CORPORATION and its licensors retain all intellectual property +# and proprietary rights in and to this software, related documentation +# and any modifications thereto. Any use, reproduction, disclosure or +# distribution of this software and related documentation without an express +# license agreement from NVIDIA CORPORATION is strictly prohibited. +"""Custom replacement for `torch.nn.functional.grid_sample` that supports +arbitrarily high order gradients between the input and output. + +Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`, +`align_corners=False`. +""" + +import warnings + +import torch + +# pylint: disable=redefined-builtin +# pylint: disable=arguments-differ +# pylint: disable=protected-access + +# ---------------------------------------------------------------------------- + +enabled = True # Enable the custom op by setting this to true. 
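+
+# A minimal usage sketch (illustrative only, not part of the module):
+# ``grid_sample`` below is a drop-in replacement for
+# torch.nn.functional.grid_sample with mode='bilinear',
+# padding_mode='zeros' and align_corners=False hard-coded, e.g.
+#
+#     warped = grid_sample(images, grid)  # images: [N, C, H, W],
+#                                         # grid:   [N, H, W, 2] in [-1, 1]
+#
+# Setting ``enabled = False`` (or running on an unsupported PyTorch version)
+# falls back to the standard op without the custom double-backward path.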
+ +# ---------------------------------------------------------------------------- + + +def grid_sample(input, grid): + if _should_use_custom_op(): + return _GridSample2dForward.apply(input, grid) + return torch.nn.functional.grid_sample( + input=input, + grid=grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + + +# ---------------------------------------------------------------------------- + + +def _should_use_custom_op(): + if not enabled: + return False + if any( + torch.__version__.startswith(x) + for x in ['1.5.', '1.6.', '1.7.', '1.8.', '1.9.', '1.10.']): + return True + warnings.warn( + f'grid_sample_gradfix not supported on PyTorch {torch.__version__}.' + ' Falling back to torch.nn.functional.grid_sample().') + return False + + +# ---------------------------------------------------------------------------- + + +class _GridSample2dForward(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, grid): + assert input.ndim == 4 + assert grid.ndim == 4 + output = torch.nn.functional.grid_sample( + input=input, + grid=grid, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + ctx.save_for_backward(input, grid) + return output + + @staticmethod + def backward(ctx, grad_output): + input, grid = ctx.saved_tensors + grad_input, grad_grid = _GridSample2dBackward.apply( + grad_output, input, grid) + return grad_input, grad_grid + + +# ---------------------------------------------------------------------------- + + +class _GridSample2dBackward(torch.autograd.Function): + + @staticmethod + def forward(ctx, grad_output, input, grid): + op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') + grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) + ctx.save_for_backward(grid) + return grad_input, grad_grid + + @staticmethod + def backward(ctx, grad2_grad_input, grad2_grad_grid): + _ = grad2_grad_grid # unused + grid, = ctx.saved_tensors + grad2_grad_output = None + grad2_input = None + grad2_grid = None + + if ctx.needs_input_grad[0]: + grad2_grad_output = _GridSample2dForward.apply( + grad2_grad_input, grid) + + assert not ctx.needs_input_grad[2] + return grad2_grad_output, grad2_input, grad2_grid diff --git a/mmgen/models/architectures/stylegan/ada/misc.py b/mmgen/models/architectures/stylegan/ada/misc.py new file mode 100644 index 000000000..6a7b95597 --- /dev/null +++ b/mmgen/models/architectures/stylegan/ada/misc.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch + +# Cached construction of constant tensors. Avoids CPU=>GPU copy when the +# same constant is used multiple times. 
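+# For example, ``constant([0, 0], device=x.device)`` returns the same cached
+# tensor on repeated calls instead of rebuilding it each time (illustrative
+# usage; the augment pipeline uses it for padding margins and the luma axis).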
+ +_constant_cache = dict() + + +def constant(value, shape=None, dtype=None, device=None, memory_format=None): + value = np.asarray(value) + if shape is not None: + shape = tuple(shape) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device('cpu') + if memory_format is None: + memory_format = torch.contiguous_format + + key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, + memory_format) + tensor = _constant_cache.get(key, None) + if tensor is None: + tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) + if shape is not None: + tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) + tensor = tensor.contiguous(memory_format=memory_format) + _constant_cache[key] = tensor + return tensor diff --git a/mmgen/models/architectures/stylegan/ada/upfirdn2d.py b/mmgen/models/architectures/stylegan/ada/upfirdn2d.py new file mode 100644 index 000000000..0d15167f1 --- /dev/null +++ b/mmgen/models/architectures/stylegan/ada/upfirdn2d.py @@ -0,0 +1,193 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from mmcv.ops.upfirdn2d import upfirdn2d + + +def _parse_scaling(scaling): + if isinstance(scaling, int): + scaling = [scaling, scaling] + assert isinstance(scaling, (list, tuple)) + assert all(isinstance(x, int) for x in scaling) + sx, sy = scaling + assert sx >= 1 and sy >= 1 + return sx, sy + + +def _parse_padding(padding): + if isinstance(padding, int): + padding = [padding, padding] + assert isinstance(padding, (list, tuple)) + assert all(isinstance(x, int) for x in padding) + if len(padding) == 2: + padx, pady = padding + padding = [padx, padx, pady, pady] + padx0, padx1, pady0, pady1 = padding + return padx0, padx1, pady0, pady1 + + +def _get_filter_size(f): + if f is None: + return 1, 1 + assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] + fw = f.shape[-1] + fh = f.shape[0] + fw = int(fw) + fh = int(fh) + assert fw >= 1 and fh >= 1 + return fw, fh + + +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): + r"""Upsample a batch of 2D images using the given 2D FIR filter. + By default, the result is padded so that its shape is a multiple of the + input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + up: Integer upsampling factor. Can be a single int or a + list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the output. Can be a single + number or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, + y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` + (default: `'cuda'`). 
+ Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]` + """ + upx, upy = _parse_scaling(up) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw + upx - 1) // 2, + padx1 + (fw - upx) // 2, + pady0 + (fh + upy - 1) // 2, + pady1 + (fh - upy) // 2, + ] + + gain = gain * upx * upy + f = f * (gain**(f.ndim / 2)) + if flip_filter: + f = f.flip(list(range(f.ndim))) + if f.ndim == 1: + x = upfirdn2d(x, f.unsqueeze(0), up=(upx, 1), pad=(p[0], p[1], 0, 0)) + x = upfirdn2d(x, f.unsqueeze(1), up=(1, upy), pad=(0, 0, p[2], p[3])) + return x + + +def setup_filter(f, + device=torch.device('cpu'), + normalize=True, + flip_filter=False, + gain=1, + separable=None): + r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. + Args: + f: Torch tensor, numpy array, or python list of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), + `[]` (impulse), or + `None` (identity). + device: Result device (default: cpu). + normalize: Normalize the filter so that it retains the magnitude + for constant input signal (DC)? (default: True). + flip_filter: Flip the filter? (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + separable: Return a separable filter? (default: select automatically) + Returns: + Float32 tensor of the shape + `[filter_height, filter_width]` (non-separable) or + `[filter_taps]` (separable). + """ + # Validate. + if f is None: + f = 1 + f = torch.as_tensor(f, dtype=torch.float32) + assert f.ndim in [0, 1, 2] + assert f.numel() > 0 + if f.ndim == 0: + f = f[np.newaxis] + + # Separable? + if separable is None: + separable = (f.ndim == 1 and f.numel() >= 8) + if f.ndim == 1 and not separable: + f = f.ger(f) + assert f.ndim == (1 if separable else 2) + + # Apply normalize, flip, gain, and device. + if normalize: + f /= f.sum() + if flip_filter: + f = f.flip(list(range(f.ndim))) + f = f * (gain**(f.ndim / 2)) + f = f.to(device=device) + return f + + +def downsample2d(x, + f, + down=2, + padding=0, + flip_filter=False, + gain=1, + impl='cuda'): + r"""Downsample a batch of 2D images using the given 2D FIR filter. + By default, the result is padded so that its shape is a fraction of the + input. + User-specified padding is applied on top of that, with negative values + indicating cropping. Pixels outside the image are assumed to be zero. + Args: + x: Float32/float64/float16 input tensor of the shape + `[batch_size, num_channels, in_height, in_width]`. + f: Float32 FIR filter of the shape + `[filter_height, filter_width]` (non-separable), + `[filter_taps]` (separable), or + `None` (identity). + down: Integer downsampling factor. Can be a single int or a + list/tuple + `[x, y]` (default: 1). + padding: Padding with respect to the input. Can be a single number + or a + list/tuple `[x, y]` or `[x_before, x_after, y_before, + y_after]` + (default: 0). + flip_filter: False = convolution, True = correlation (default: False). + gain: Overall scaling factor for signal magnitude (default: 1). + impl: Implementation to use. Can be `'ref'` or `'cuda'` + (default: `'cuda'`). 
+ Returns: + Tensor of the shape `[batch_size, num_channels, out_height, out_width]` + """ + downx, downy = _parse_scaling(down) + padx0, padx1, pady0, pady1 = _parse_padding(padding) + fw, fh = _get_filter_size(f) + p = [ + padx0 + (fw - downx + 1) // 2, + padx1 + (fw - downx) // 2, + pady0 + (fh - downy + 1) // 2, + pady1 + (fh - downy) // 2, + ] + if flip_filter: + f = f.flip(list(range(f.ndim))) + if f.ndim == 1: + x = upfirdn2d(x, + f.unsqueeze(0), + down=(downx, 1), + pad=(p[0], p[1], 0, 0)) + x = upfirdn2d(x, + f.unsqueeze(1), + down=(1, downy), + pad=(0, 0, p[2], p[3])) + return x diff --git a/mmgen/models/architectures/stylegan/generator_discriminator_v2.py b/mmgen/models/architectures/stylegan/generator_discriminator_v2.py index 01ea222db..ecd264023 100644 --- a/mmgen/models/architectures/stylegan/generator_discriminator_v2.py +++ b/mmgen/models/architectures/stylegan/generator_discriminator_v2.py @@ -10,7 +10,9 @@ from mmgen.core.runners.fp16_utils import auto_fp16 from mmgen.models.architectures import PixelNorm from mmgen.models.architectures.common import get_module_device -from mmgen.models.builder import MODULES +from mmgen.models.builder import MODULES, build_module +from .ada.augment import AugmentPipe +from .ada.misc import constant from .modules.styleganv2_modules import (ConstantInput, ConvDownLayer, EqualLinearActModule, ModMBStddevLayer, ModulatedStyleConv, @@ -501,6 +503,10 @@ class StyleGAN2Discriminator(nn.Module): fp32 if not `fp16_enabled`. This argument is designed to deal with the cases where some modules are run in FP16 and others in FP32. Defaults to True. + input_bgr2rgb (bool, optional): Whether to reformat the input channels + with order `rgb`. Since we provide several converted weights, + whose input order is `rgb`. You can set this argument to True if + you want to finetune on converted weights. Defaults to False. pretrained (dict | None, optional): Information for pretained models. The necessary key is 'ckpt_path'. Besides, you can also provide 'prefix' to load the generator part from the whole state dict. @@ -516,6 +522,7 @@ def __init__(self, fp16_enabled=False, out_fp32=True, convert_input_fp32=True, + input_bgr2rgb=False, pretrained=None): super().__init__() self.num_fp16_scale = num_fp16_scales @@ -572,6 +579,8 @@ def __init__(self, act_cfg=dict(type='fused_bias')), EqualLinearActModule(channels[4], 1), ) + + self.input_bgr2rgb = input_bgr2rgb if pretrained is not None: self._load_pretrained_model(**pretrained) @@ -595,6 +604,10 @@ def forward(self, x): Returns: torch.Tensor: Predict score for the input image. """ + # This setting was used to finetune on converted weights + if self.input_bgr2rgb: + x = x[:, [2, 1, 0], ...] + x = self.convs(x) x = self.mbstd_layer(x) @@ -605,3 +618,88 @@ def forward(self, x): x = self.final_linear(x) return x + + +@MODULES.register_module() +class ADAStyleGAN2Discriminator(StyleGAN2Discriminator): + + def __init__(self, in_size, *args, data_aug=None, **kwargs): + """StyleGANv2 Discriminator with adaptive augmentation. + + Args: + in_size (int): The input size of images. + data_aug (dict, optional): Config for data + augmentation. Defaults to None. 
+ """ + super().__init__(in_size, *args, **kwargs) + self.with_ada = data_aug is not None + if self.with_ada: + self.ada_aug = build_module(data_aug) + self.ada_aug.requires_grad = False + self.log_size = int(np.log2(in_size)) + + def forward(self, x): + """Forward function.""" + if self.with_ada: + x = self.ada_aug.aug_pipeline(x) + return super().forward(x) + + +@MODULES.register_module() +class ADAAug(nn.Module): + """Data Augmentation Module for Adaptive Discriminator augmentation. + + Args: + aug_pipeline (dict, optional): Config for augmentation pipeline. + Defaults to None. + update_interval (int, optional): Interval for updating + augmentation probability. Defaults to 4. + augment_initial_p (float, optional): Initial augmentation + probability. Defaults to 0.. + ada_target (float, optional): ADA target. Defaults to 0.6. + ada_kimg (int, optional): ADA training duration. Defaults to 500. + """ + + def __init__(self, + aug_pipeline=None, + update_interval=4, + augment_initial_p=0., + ada_target=0.6, + ada_kimg=500, + use_slow_aug=False): + super().__init__() + self.aug_pipeline = AugmentPipe(**aug_pipeline) + self.update_interval = update_interval + self.ada_kimg = ada_kimg + self.ada_target = ada_target + + self.aug_pipeline.p.copy_(torch.tensor(augment_initial_p)) + + # this log buffer stores two numbers: num_scalars, sum_scalars. + self.register_buffer('log_buffer', torch.zeros((2, ))) + + def update(self, iteration=0, num_batches=0): + """Update Augment probability. + + Args: + iteration (int, optional): Training iteration. + Defaults to 0. + num_batches (int, optional): The number of reals batches. + Defaults to 0. + """ + + if (iteration + 1) % self.update_interval == 0: + + adjust_step = float(num_batches * self.update_interval) / float( + self.ada_kimg * 1000.) + + # get the mean value as the ada heuristic + ada_heuristic = self.log_buffer[1] / self.log_buffer[0] + adjust = np.sign(ada_heuristic.item() - + self.ada_target) * adjust_step + # update the augment p + # Note that p may be bigger than 1.0 + self.aug_pipeline.p.copy_((self.aug_pipeline.p + adjust).max( + constant(0, device=self.log_buffer.device))) + + self.log_buffer = self.log_buffer * 0. diff --git a/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py b/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py index daa481dab..ffeeefad8 100644 --- a/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py +++ b/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py @@ -461,13 +461,19 @@ def __init__( int(pad_hi[1]) ] - def forward(self, - x, - w, - noise_mode='random', - force_fp32=False, - update_emas=False): - assert noise_mode in ['random', 'const', 'none'] # unused + def forward(self, x, w, force_fp32=False, update_emas=False): + """Forward function for synthesis layer. + Args: + x (torch.Tensor): Input feature map tensor. + w (torch.Tensor): Input style tensor. + force_fp32 (bool, optional): Force fp32 ignore the weights. + Defaults to True. + update_emas (bool, optional): Whether update moving average of + input magnitude. Defaults to False. + + Returns: + torch.Tensor: Output feature map tensor map. + """ # Track input magnitude. if update_emas: @@ -517,6 +523,18 @@ def forward(self, @staticmethod def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): + """Design lowpass filter giving related arguments., + Args: + numtaps (int): Length of the filter. `numtaps` must be odd if a + passband includes the Nyquist frequency. 
+ cutoff (float): Cutoff frequency of filter + width (float): The approximate width of the transition region. + fs (float): The sampling frequency of the signal. + radial (bool, optional): Whether use radially symmetric jinc-based + filter. Defaults to False. + Returns: + torch.Tensor: Kernel of lowpass filter. + """ assert numtaps >= 1 # Identity filter. @@ -657,6 +675,7 @@ def __init__( self.layer_names.append(name) def forward(self, ws, **layer_kwargs): + """Forward function.""" ws = ws.to(torch.float32).unbind(dim=1) # Execute layers. diff --git a/mmgen/models/gans/static_unconditional_gan.py b/mmgen/models/gans/static_unconditional_gan.py index 41b662082..322af11f2 100644 --- a/mmgen/models/gans/static_unconditional_gan.py +++ b/mmgen/models/gans/static_unconditional_gan.py @@ -270,6 +270,17 @@ def train_step(self, else: optimizer['generator'].step() + # update ada p + if hasattr(self.discriminator, + 'with_ada') and self.discriminator.with_ada: + self.discriminator.ada_aug.log_buffer[0] += batch_size + self.discriminator.ada_aug.log_buffer[1] += disc_pred_real.sign( + ).sum() + self.discriminator.ada_aug.update( + iteration=curr_iter, num_batches=batch_size) + log_vars_disc['augment'] = ( + self.discriminator.ada_aug.aug_pipeline.p.data.cpu()) + log_vars = {} log_vars.update(log_vars_g) log_vars.update(log_vars_disc) From 628a09bbe4577185487d459d624771dbf993681f Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Thu, 31 Mar 2022 21:57:24 +0800 Subject: [PATCH 02/11] fix lint --- ..._r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py | 100 ++++++++++++++++++ ..._t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py | 97 +++++++++++++++++ ...3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py | 4 +- .../architectures/stylegan/ada/augment.py | 97 +++++++++-------- .../architectures/stylegan/ada/upfirdn2d.py | 12 +-- .../stylegan/generator_discriminator_v2.py | 3 +- .../stylegan/modules/styleganv3_modules.py | 2 +- tests/test_modules/test_stylev2_archs.py | 84 ++++++++++++++- 8 files changed, 342 insertions(+), 57 deletions(-) create mode 100644 configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py create mode 100644 configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py diff --git a/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py b/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py new file mode 100644 index 000000000..a6ec5be48 --- /dev/null +++ b/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py @@ -0,0 +1,100 @@ +_base_ = [ + '../_base_/models/stylegan/stylegan3_base.py', + '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py' +] + +synthesis_cfg = { + 'type': 'SynthesisNetwork', + 'channel_base': 65536, + 'channel_max': 1024, + 'magnitude_ema_beta': 0.999, + 'conv_kernel': 1, + 'use_radial_filters': True +} +r1_gamma = 3.3 # set by user +d_reg_interval = 16 + +load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth' # noqa + +# ada settings +aug_kwargs = { + 'xflip': 1, + 'rotate90': 1, + 'xint': 1, + 'scale': 1, + 'rotate': 1, + 'aniso': 1, + 'xfrac': 1, + 'brightness': 1, + 'contrast': 1, + 'lumaflip': 1, + 'hue': 1, + 'saturation': 1 +} + +model = dict( + type='StaticUnconditionalGAN', + generator=dict( + out_size=1024, + img_channels=3, + rgb2bgr=True, + synthesis_cfg=synthesis_cfg), + discriminator=dict( + type='ADAStyleGAN2Discriminator', + in_size=1024, + input_bgr2rgb=True, + data_aug=dict(type='ADAAug', 
aug_pipeline=aug_kwargs, ada_kimg=100)), + gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'), + disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval)) + +imgs_root = 'data/metfaces/images/' +data = dict( + samples_per_gpu=4, + train=dict(dataset=dict(imgs_root=imgs_root)), + val=dict(imgs_root=imgs_root)) + +ema_half_life = 10. # G_smoothing_kimg + +ema_kimg = 10 +ema_nimg = ema_kimg * 1000 +ema_beta = 0.5 ** (32 / max(ema_nimg, 1e-8)) + +custom_hooks = [ + dict( + type='VisualizeUnconditionalSamples', + output_dir='training_samples', + interval=5000), + dict( + type='ExponentialMovingAverageHook', + module_keys=('generator_ema', ), + interp_mode='lerp', + interp_cfg=dict(momentum=ema_beta), + interval=1, + start_iter=0, + priority='VERY_HIGH') +] + +inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl' +metrics = dict( + fid50k=dict( + type='FID', + num_images=50000, + inception_pkl=inception_pkl, + inception_args=dict(type='StyleGAN'), + bgr2rgb=True)) + +inception_path = None # noqa +evaluation = dict( + type='GenerativeEvalHook', + interval=dict(milestones=[100000],interval=[10000, 5000]), + metrics=dict( + type='FID', + num_images=50000, + inception_pkl=inception_pkl, + inception_args=dict(type='StyleGAN', inception_path=inception_path), + bgr2rgb=True), + sample_kwargs=dict(sample_model='ema')) + +lr_config = None + +total_iters = 160000 diff --git a/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py b/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py new file mode 100644 index 000000000..df5408b58 --- /dev/null +++ b/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py @@ -0,0 +1,97 @@ +_base_ = [ + '../_base_/models/stylegan/stylegan3_base.py', + '../_base_/datasets/ffhq_flip.py', '../_base_/default_runtime.py' +] + +synthesis_cfg = { + 'type': 'SynthesisNetwork', + 'channel_base': 32768, + 'channel_max': 512, + 'magnitude_ema_beta': 0.999 +} +r1_gamma = 6.6 # set by user +d_reg_interval = 16 + +load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth' # noqa +# ada settings +aug_kwargs = { + 'xflip': 1, + 'rotate90': 1, + 'xint': 1, + 'scale': 1, + 'rotate': 1, + 'aniso': 1, + 'xfrac': 1, + 'brightness': 1, + 'contrast': 1, + 'lumaflip': 1, + 'hue': 1, + 'saturation': 1 +} + +model = dict( + type='StaticUnconditionalGAN', + generator=dict( + out_size=1024, + img_channels=3, + rgb2bgr=True, + synthesis_cfg=synthesis_cfg), + discriminator=dict( + type='ADAStyleGAN2Discriminator', + in_size=1024, + input_bgr2rgb=True, + data_aug=dict(type='ADAAug', aug_pipeline=aug_kwargs, ada_kimg=100)), + gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns'), + disc_auxiliary_loss=dict(loss_weight=r1_gamma / 2.0 * d_reg_interval)) + +imgs_root = 'data/metfaces/images/' +data = dict( + samples_per_gpu=4, + train=dict(dataset=dict(imgs_root=imgs_root)), + val=dict(imgs_root=imgs_root)) + +ema_half_life = 10. 
# G_smoothing_kimg + +ema_kimg = 10 +ema_nimg = ema_kimg * 1000 +ema_beta = 0.5 ** (32 / max(ema_nimg, 1e-8)) + +custom_hooks = [ + dict( + type='VisualizeUnconditionalSamples', + output_dir='training_samples', + interval=5000), + dict( + type='ExponentialMovingAverageHook', + module_keys=('generator_ema', ), + interp_mode='lerp', + interp_cfg=dict(momentum=ema_beta), + interval=1, + start_iter=0, + priority='VERY_HIGH') +] + +inception_pkl = 'work_dirs/inception_pkl/metface_1024x1024_noflip.pkl' +metrics = dict( + fid50k=dict( + type='FID', + num_images=50000, + inception_pkl=inception_pkl, + inception_args=dict(type='StyleGAN'), + bgr2rgb=True)) + +inception_path = None # set by user +evaluation = dict( + type='GenerativeEvalHook', + interval=dict(milestones=[80000],interval=[10000, 5000]), + metrics=dict( + type='FID', + num_images=50000, + inception_pkl=inception_pkl, + inception_args=dict(type='StyleGAN', inception_path=inception_path), + bgr2rgb=True), + sample_kwargs=dict(sample_model='ema')) + +lr_config = None + +total_iters = 160000 diff --git a/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py b/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py index 9238f12b6..33a5c36f4 100644 --- a/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py +++ b/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py @@ -56,7 +56,7 @@ inception_args=dict(type='StyleGAN'), bgr2rgb=True)) -inception_path = '/mnt/lustre/yangyifei1/repos/fix-s3/work_dirs/cache/inception-2015-12-05.pt' # noqa +inception_path = None # set by user evaluation = dict( type='GenerativeEvalHook', interval=10000, @@ -72,5 +72,3 @@ lr_config = None total_iters = 800002 - -allow_tf32 = False diff --git a/mmgen/models/architectures/stylegan/ada/augment.py b/mmgen/models/architectures/stylegan/ada/augment.py index 28f20e0b2..80badbdd0 100644 --- a/mmgen/models/architectures/stylegan/ada/augment.py +++ b/mmgen/models/architectures/stylegan/ada/augment.py @@ -107,6 +107,7 @@ def matrix(*rows, device=None): """Constructing transformation matrices. + Args: device (str|torch.device, optional): Matrix device. Defaults to None. Returns: @@ -129,6 +130,7 @@ def matrix(*rows, device=None): def translate2d(tx, ty, **kwargs): """Construct 2d translation matrix. + Args: tx (float): X-direction translation amount. ty (float): Y-direction translation amount. @@ -141,6 +143,7 @@ def translate2d(tx, ty, **kwargs): def translate3d(tx, ty, tz, **kwargs): """Construct 3d translation matrix. + Args: tx (float): X-direction translation amount. ty (float): Y-direction translation amount. @@ -155,6 +158,7 @@ def translate3d(tx, ty, tz, **kwargs): def scale2d(sx, sy, **kwargs): """Construct 2d scaling matrix. + Args: sx (float): X-direction scaling coefficient. sy (float): Y-direction scaling coefficient. @@ -167,6 +171,7 @@ def scale2d(sx, sy, **kwargs): def scale3d(sx, sy, sz, **kwargs): """Construct 3d scaling matrix. + Args: sx (float): X-direction scaling coefficient. sy (float): Y-direction scaling coefficient. @@ -181,6 +186,7 @@ def scale3d(sx, sy, sz, **kwargs): def rotate2d(theta, **kwargs): """Construct 2d rotating matrix. + Args: theta (float): Counter-clock wise rotation angle. Returns: @@ -193,6 +199,7 @@ def rotate2d(theta, **kwargs): def rotate3d(v, theta, **kwargs): """Constructing 3d rotating matrix. + Args: v (torch.Tensor): Luma axis. theta (float): Rotate theta counter-clock wise with ``v`` as the axis. 
@@ -215,6 +222,7 @@ def rotate3d(v, theta, **kwargs): def translate2d_inv(tx, ty, **kwargs): """Construct inverse matrix of 2d translation matrix. + Args: tx (float): X-direction translation amount. ty (float): Y-direction translation amount. @@ -227,6 +235,7 @@ def translate2d_inv(tx, ty, **kwargs): def scale2d_inv(sx, sy, **kwargs): """Construct inverse matrix of 2d scaling matrix. + Args: sx (float): X-direction scaling coefficient. sy (float): Y-direction scaling coefficient. @@ -239,6 +248,7 @@ def scale2d_inv(sx, sy, **kwargs): def rotate2d_inv(theta, **kwargs): """Construct inverse matrix of 2d rotating matrix. + Args: theta (float): Counter-clock wise rotation angle. Returns: @@ -259,9 +269,11 @@ def rotate2d_inv(theta, **kwargs): class AugmentPipe(torch.nn.Module): """Augmentation pipeline include multiple geometric and color transformations. + Note: The meaning of arguments are written in the comments of ``__init__`` function. """ + def __init__( self, xflip=0, @@ -388,9 +400,8 @@ def forward(self, images, debug_percentile=None): batch_size, num_channels, height, width = images.shape device = images.device if debug_percentile is not None: - debug_percentile = torch.as_tensor(debug_percentile, - dtype=torch.float32, - device=device) + debug_percentile = torch.as_tensor( + debug_percentile, dtype=torch.float32, device=device) # ------------------------------------- # Select parameters for pixel blitting. @@ -431,8 +442,8 @@ def forward(self, images, debug_percentile=None): if debug_percentile is not None: t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max) - G_inv = G_inv @ translate2d_inv(torch.round(t[:, 0] * width), - torch.round(t[:, 1] * height)) + G_inv = G_inv @ translate2d_inv( + torch.round(t[:, 0] * width), torch.round(t[:, 1] * height)) # -------------------------------------------------------- # Select parameters for general geometric transformations. @@ -537,40 +548,41 @@ def forward(self, images, debug_percentile=None): mx0, my0, mx1, my1 = margin.ceil().to(torch.int32) # Pad image and adjust origin. - images = torch.nn.functional.pad(input=images, - pad=[mx0, mx1, my0, my1], - mode='reflect') - G_inv = translate2d(torch.true_divide(mx0 - mx1, 2), - torch.true_divide(my0 - my1, 2)) @ G_inv + images = torch.nn.functional.pad( + input=images, pad=[mx0, mx1, my0, my1], mode='reflect') + G_inv = translate2d( + torch.true_divide(mx0 - mx1, 2), torch.true_divide( + my0 - my1, 2)) @ G_inv # Upsample. images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2) - G_inv = scale2d(2, 2, device=device) @ G_inv @ scale2d_inv( - 2, 2, device=device) - G_inv = translate2d(-0.5, -0.5, - device=device) @ G_inv @ translate2d_inv( - -0.5, -0.5, device=device) + G_inv = scale2d( + 2, 2, device=device) @ G_inv @ scale2d_inv( + 2, 2, device=device) + G_inv = translate2d( + -0.5, -0.5, device=device) @ G_inv @ translate2d_inv( + -0.5, -0.5, device=device) # Execute transformation. 
shape = [ batch_size, num_channels, (height + Hz_pad * 2) * 2, (width + Hz_pad * 2) * 2 ] - G_inv = scale2d(2 / images.shape[3], - 2 / images.shape[2], - device=device) @ G_inv @ scale2d_inv( - 2 / shape[3], 2 / shape[2], device=device) - grid = torch.nn.functional.affine_grid(theta=G_inv[:, :2, :], - size=shape, - align_corners=False) + G_inv = scale2d( + 2 / images.shape[3], 2 / images.shape[2], + device=device) @ G_inv @ scale2d_inv( + 2 / shape[3], 2 / shape[2], device=device) + grid = torch.nn.functional.affine_grid( + theta=G_inv[:, :2, :], size=shape, align_corners=False) images = grid_sample_gradfix.grid_sample(images, grid) # Downsample and crop. - images = upfirdn2d.downsample2d(x=images, - f=self.Hz_geom, - down=2, - padding=-Hz_pad * 2, - flip_filter=True) + images = upfirdn2d.downsample2d( + x=images, + f=self.Hz_geom, + down=2, + padding=-Hz_pad * 2, + flip_filter=True) # -------------------------------------------- # Select parameters for color transformations. @@ -611,8 +623,8 @@ def forward(self, images, debug_percentile=None): C = scale3d(c, c, c) @ C # Apply luma flip with probability (lumaflip * strength). - v = misc.constant(np.asarray([1, 1, 1, 0]) / np.sqrt(3), - device=device) # Luma axis. + v = misc.constant( + np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device) # Luma axis. if self.lumaflip > 0: i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2) i = torch.where( @@ -661,8 +673,8 @@ def forward(self, images, debug_percentile=None): images = C[:, :3, :3] @ images + C[:, :3, 3:] elif num_channels == 1: C = C[:, :3, :].mean(dim=1, keepdims=True) - images = images * C[:, :, :3].sum(dim=2, - keepdims=True) + C[:, :, 3:] + images = images * C[:, :, :3].sum( + dim=2, keepdims=True) + C[:, :, 3:] else: raise ValueError( 'Image must be RGB (3 channels) or L (1 channel)') @@ -716,15 +728,16 @@ def forward(self, images, debug_percentile=None): p = self.Hz_fbank.shape[1] // 2 images = images.reshape( [1, batch_size * num_channels, height, width]) - images = torch.nn.functional.pad(input=images, - pad=[p, p, p, p], - mode='reflect') - images = conv2d_gradfix.conv2d(input=images, - weight=Hz_prime.unsqueeze(2), - groups=batch_size * num_channels) - images = conv2d_gradfix.conv2d(input=images, - weight=Hz_prime.unsqueeze(3), - groups=batch_size * num_channels) + images = torch.nn.functional.pad( + input=images, pad=[p, p, p, p], mode='reflect') + images = conv2d_gradfix.conv2d( + input=images, + weight=Hz_prime.unsqueeze(2), + groups=batch_size * num_channels) + images = conv2d_gradfix.conv2d( + input=images, + weight=Hz_prime.unsqueeze(3), + groups=batch_size * num_channels) images = images.reshape([batch_size, num_channels, height, width]) # ------------------------ @@ -759,8 +772,8 @@ def forward(self, images, debug_percentile=None): size = torch.full_like(size, self.cutout_size) center = torch.full_like(center, debug_percentile) coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1]) - coord_y = torch.arange(height, - device=device).reshape([1, 1, -1, 1]) + coord_y = torch.arange( + height, device=device).reshape([1, 1, -1, 1]) mask_x = (((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2) mask_y = (((coord_y + 0.5) / height - center[:, 1]).abs() >= diff --git a/mmgen/models/architectures/stylegan/ada/upfirdn2d.py b/mmgen/models/architectures/stylegan/ada/upfirdn2d.py index 0d15167f1..51ead27ca 100644 --- a/mmgen/models/architectures/stylegan/ada/upfirdn2d.py +++ b/mmgen/models/architectures/stylegan/ada/upfirdn2d.py @@ -182,12 
+182,8 @@ def downsample2d(x, if flip_filter: f = f.flip(list(range(f.ndim))) if f.ndim == 1: - x = upfirdn2d(x, - f.unsqueeze(0), - down=(downx, 1), - pad=(p[0], p[1], 0, 0)) - x = upfirdn2d(x, - f.unsqueeze(1), - down=(1, downy), - pad=(0, 0, p[2], p[3])) + x = upfirdn2d( + x, f.unsqueeze(0), down=(downx, 1), pad=(p[0], p[1], 0, 0)) + x = upfirdn2d( + x, f.unsqueeze(1), down=(1, downy), pad=(0, 0, p[2], p[3])) return x diff --git a/mmgen/models/architectures/stylegan/generator_discriminator_v2.py b/mmgen/models/architectures/stylegan/generator_discriminator_v2.py index ecd264023..784734404 100644 --- a/mmgen/models/architectures/stylegan/generator_discriminator_v2.py +++ b/mmgen/models/architectures/stylegan/generator_discriminator_v2.py @@ -665,8 +665,7 @@ def __init__(self, update_interval=4, augment_initial_p=0., ada_target=0.6, - ada_kimg=500, - use_slow_aug=False): + ada_kimg=500): super().__init__() self.aug_pipeline = AugmentPipe(**aug_pipeline) self.update_interval = update_interval diff --git a/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py b/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py index ffeeefad8..d62829756 100644 --- a/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py +++ b/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py @@ -472,7 +472,7 @@ def forward(self, x, w, force_fp32=False, update_emas=False): input magnitude. Defaults to False. Returns: - torch.Tensor: Output feature map tensor map. + torch.Tensor: Output feature map tensor. """ # Track input magnitude. diff --git a/tests/test_modules/test_stylev2_archs.py b/tests/test_modules/test_stylev2_archs.py index 9a3c83b17..7753d2714 100644 --- a/tests/test_modules/test_stylev2_archs.py +++ b/tests/test_modules/test_stylev2_archs.py @@ -5,7 +5,7 @@ import torch from mmgen.models.architectures.stylegan.generator_discriminator_v2 import ( - StyleGAN2Discriminator, StyleGANv2Generator) + ADAStyleGAN2Discriminator, StyleGAN2Discriminator, StyleGANv2Generator) from mmgen.models.architectures.stylegan.modules import (Blur, ModulatedStyleConv, ModulatedToRGB) @@ -441,6 +441,88 @@ def test_fp16_stylegan2_disc_cuda(self): assert score.dtype == torch.float32 +class TestADAStyleGAN2Discriminator: + + @classmethod + def setup_class(cls): + aug_kwargs = { + 'xflip': 1, + 'rotate90': 1, + 'xint': 1, + 'scale': 1, + 'rotate': 1, + 'aniso': 1, + 'xfrac': 1, + 'brightness': 1, + 'contrast': 1, + 'lumaflip': 1, + 'hue': 1, + 'saturation': 1 + } + cls.default_cfg = dict( + in_size=64, + input_bgr2rgb=True, + data_aug=dict( + type='ADAAug', + update_interval=2, + aug_pipeline=aug_kwargs, + ada_kimg=100)) + + def test_ada_stylegan2_disc_cpu(self): + d = ADAStyleGAN2Discriminator(**self.default_cfg) + img = torch.randn((2, 3, 64, 64)) + score = d(img) + assert score.shape == (2, 1) + + # test ada p update + curr_iter = 0 + batch_size = 2 + score = torch.tensor([1., 1.]) + d.ada_aug.log_buffer[0] += 2 + d.ada_aug.log_buffer[1] += score.sign().sum() + d.ada_aug.update(iteration=curr_iter, num_batches=batch_size) + assert d.ada_aug.aug_pipeline.p == 0. + + curr_iter += 1 + d.ada_aug.log_buffer[0] += 2 + d.ada_aug.log_buffer[1] += score.sign().sum() + d.ada_aug.update(iteration=curr_iter, num_batches=batch_size) + assert d.ada_aug.aug_pipeline.p == 4.0000e-05 + + # test with p=1. 
+ d.aug_pipeline.p.copy_(torch.tensor(1.)) + img = torch.randn((2, 3, 64, 64)) + score = d(img) + assert score.shape == (2, 1) + + def test_ada_stylegan2_disc_cuda(self): + d = ADAStyleGAN2Discriminator(**self.default_cfg).cuda() + img = torch.randn((2, 3, 64, 64)).cuda() + score = d(img) + assert score.shape == (2, 1) + + # test ada p update + curr_iter = 0 + batch_size = 2 + score = torch.tensor([1., 1.]).cuda() + d.ada_aug.log_buffer[0] += 2 + d.ada_aug.log_buffer[1] += score.sign().sum() + d.ada_aug.update(iteration=curr_iter, num_batches=batch_size) + assert d.ada_aug.aug_pipeline.p == 0. + + curr_iter += 1 + d.ada_aug.log_buffer[0] += 2 + d.ada_aug.log_buffer[1] += score.sign().sum() + d.ada_aug.update(iteration=curr_iter, num_batches=batch_size) + assert d.ada_aug.aug_pipeline.p == 4.0000e-05 + + # test with p=1. + d.aug_pipeline.p.copy_(torch.tensor(1.)) + img = torch.randn((2, 3, 64, 64)).cuda() + score = d(img) + assert score.shape == (2, 1) + + class TestMSStyleGANv2Disc: @classmethod From 0abd7207d65225c0866dfee51c3ebe381f5e3285 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Thu, 31 Mar 2022 22:56:14 +0800 Subject: [PATCH 03/11] fix lint --- ...egan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py | 6 +++--- ...egan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py | 14 +++++++------- tests/test_modules/test_stylev2_archs.py | 4 ++-- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py b/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py index a6ec5be48..ea763be65 100644 --- a/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py +++ b/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py @@ -14,7 +14,7 @@ r1_gamma = 3.3 # set by user d_reg_interval = 16 -load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth' # noqa +load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_r_ffhq_1024_b4x8_cvt_official_rgb_20220329_234933-ac0500a1.pth' # noqa # ada settings aug_kwargs = { @@ -57,7 +57,7 @@ ema_kimg = 10 ema_nimg = ema_kimg * 1000 -ema_beta = 0.5 ** (32 / max(ema_nimg, 1e-8)) +ema_beta = 0.5**(32 / max(ema_nimg, 1e-8)) custom_hooks = [ dict( @@ -86,7 +86,7 @@ inception_path = None # noqa evaluation = dict( type='GenerativeEvalHook', - interval=dict(milestones=[100000],interval=[10000, 5000]), + interval=dict(milestones=[100000], interval=[10000, 5000]), metrics=dict( type='FID', num_images=50000, diff --git a/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py b/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py index df5408b58..6185f624a 100644 --- a/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py +++ b/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py @@ -12,7 +12,7 @@ r1_gamma = 6.6 # set by user d_reg_interval = 16 -load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth' # noqa +load_from = 'https://download.openmmlab.com/mmgen/stylegan3/stylegan3_t_ffhq_1024_b4x8_cvt_official_rgb_20220329_235113-db6c6580.pth' # noqa # ada settings aug_kwargs = { 'xflip': 1, @@ -54,13 +54,13 @@ ema_kimg = 10 ema_nimg = ema_kimg * 1000 -ema_beta = 0.5 ** (32 / max(ema_nimg, 1e-8)) +ema_beta = 0.5**(32 / max(ema_nimg, 1e-8)) custom_hooks = [ dict( - type='VisualizeUnconditionalSamples', - 
output_dir='training_samples', - interval=5000), + type='VisualizeUnconditionalSamples', + output_dir='training_samples', + interval=5000), dict( type='ExponentialMovingAverageHook', module_keys=('generator_ema', ), @@ -80,10 +80,10 @@ inception_args=dict(type='StyleGAN'), bgr2rgb=True)) -inception_path = None # set by user +inception_path = None # set by user evaluation = dict( type='GenerativeEvalHook', - interval=dict(milestones=[80000],interval=[10000, 5000]), + interval=dict(milestones=[80000], interval=[10000, 5000]), metrics=dict( type='FID', num_images=50000, diff --git a/tests/test_modules/test_stylev2_archs.py b/tests/test_modules/test_stylev2_archs.py index 7753d2714..3db1eb093 100644 --- a/tests/test_modules/test_stylev2_archs.py +++ b/tests/test_modules/test_stylev2_archs.py @@ -490,7 +490,7 @@ def test_ada_stylegan2_disc_cpu(self): assert d.ada_aug.aug_pipeline.p == 4.0000e-05 # test with p=1. - d.aug_pipeline.p.copy_(torch.tensor(1.)) + d.ada_aug.aug_pipeline.p.copy_(torch.tensor(1.)) img = torch.randn((2, 3, 64, 64)) score = d(img) assert score.shape == (2, 1) @@ -517,7 +517,7 @@ def test_ada_stylegan2_disc_cuda(self): assert d.ada_aug.aug_pipeline.p == 4.0000e-05 # test with p=1. - d.aug_pipeline.p.copy_(torch.tensor(1.)) + d.ada_aug.aug_pipeline.p.copy_(torch.tensor(1.)) img = torch.randn((2, 3, 64, 64)).cuda() score = d(img) assert score.shape == (2, 1) From b9e7ab5e66def5a7a2b8ad112b51184ae1cb8531 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Thu, 31 Mar 2022 23:20:18 +0800 Subject: [PATCH 04/11] skip a test without cuda --- tests/test_modules/test_stylev2_archs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_modules/test_stylev2_archs.py b/tests/test_modules/test_stylev2_archs.py index 3db1eb093..a8bbb5589 100644 --- a/tests/test_modules/test_stylev2_archs.py +++ b/tests/test_modules/test_stylev2_archs.py @@ -495,6 +495,7 @@ def test_ada_stylegan2_disc_cpu(self): score = d(img) assert score.shape == (2, 1) + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') def test_ada_stylegan2_disc_cuda(self): d = ADAStyleGAN2Discriminator(**self.default_cfg).cuda() img = torch.randn((2, 3, 64, 64)).cuda() From 3e34219f87f26d1a14d12db61837c4a1913c564f Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Fri, 1 Apr 2022 15:42:25 +0800 Subject: [PATCH 05/11] fix as comment --- .../stylegan/modules/styleganv3_modules.py | 5 +- .../test_static_unconditional_gan.py | 66 +++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py b/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py index d62829756..0e7b321d9 100644 --- a/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py +++ b/mmgen/models/architectures/stylegan/modules/styleganv3_modules.py @@ -463,6 +463,7 @@ def __init__( def forward(self, x, w, force_fp32=False, update_emas=False): """Forward function for synthesis layer. + Args: x (torch.Tensor): Input feature map tensor. w (torch.Tensor): Input style tensor. @@ -523,7 +524,8 @@ def forward(self, x, w, force_fp32=False, update_emas=False): @staticmethod def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): - """Design lowpass filter giving related arguments., + """Design lowpass filter giving related arguments. + Args: numtaps (int): Length of the filter. `numtaps` must be odd if a passband includes the Nyquist frequency. 
@@ -532,6 +534,7 @@ def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False): fs (float): The sampling frequency of the signal. radial (bool, optional): Whether use radially symmetric jinc-based filter. Defaults to False. + Returns: torch.Tensor: Kernel of lowpass filter. """ diff --git a/tests/test_models/test_static_unconditional_gan.py b/tests/test_models/test_static_unconditional_gan.py index 22fa05edf..6aca1ca08 100644 --- a/tests/test_models/test_static_unconditional_gan.py +++ b/tests/test_models/test_static_unconditional_gan.py @@ -168,3 +168,69 @@ def test_default_dcgan_model_cuda(self): data_input, optim_dict, running_status=dict(iteration=1)) assert 'loss_disc_fake' in model_outputs['log_vars'] assert 'loss_disc_fake_g' in model_outputs['log_vars'] + + def test_ada_stylegan2_model_cpu(self): + synthesis_cfg = { + 'type': 'SynthesisNetwork', + 'channel_base': 32768, + 'channel_max': 512, + 'magnitude_ema_beta': 0.999 + } + aug_kwargs = { + 'xflip': 1, + 'rotate90': 1, + 'xint': 1, + 'scale': 1, + 'rotate': 1, + 'aniso': 1, + 'xfrac': 1, + 'brightness': 1, + 'contrast': 1, + 'lumaflip': 1, + 'hue': 1, + 'saturation': 1 + } + default_config = dict( + type='StaticUnconditionalGAN', + generator=dict( + type='StyleGANv3Generator', + out_size=32, + style_channels=8, + img_channels=3, + rgb2bgr=True, + synthesis_cfg=synthesis_cfg), + discriminator=dict( + type='ADAStyleGAN2Discriminator', + in_size=32, + input_bgr2rgb=True, + data_aug=dict( + type='ADAAug', + update_interval=2, + aug_pipeline=aug_kwargs, + ada_kimg=100)), + gan_loss=dict(type='GANLoss', gan_type='wgan-logistic-ns')) + + s3gan = build_model(default_config) + assert isinstance(s3gan, StaticUnconditionalGAN) + assert not s3gan.with_disc_auxiliary_loss + assert s3gan.with_disc + + # test forward train + with pytest.raises(NotImplementedError): + _ = s3gan(None, return_loss=True) + # test forward test + imgs = s3gan(None, return_loss=False, mode='sampling', num_batches=2) + assert imgs.shape == (2, 3, 32, 32) + + # test train step + data = torch.randn((2, 3, 32, 32)) + data_input = dict(real_img=data) + optimizer_g = torch.optim.SGD(s3gan.generator.parameters(), lr=0.01) + optimizer_d = torch.optim.SGD( + s3gan.discriminator.parameters(), lr=0.01) + optim_dict = dict(generator=optimizer_g, discriminator=optimizer_d) + + _ = s3gan.train_step( + data_input, optim_dict, running_status=dict(iteration=1)) + _ = s3gan.train_step(data_input, optim_dict) + s3gan.discriminator.ada_aug.aug_pipeline.p.dtype == torch.float32 From b46518789bc17a57ecf0ac3d464605c01992d745 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Fri, 1 Apr 2022 15:59:51 +0800 Subject: [PATCH 06/11] lower ut mem --- tests/test_models/test_static_unconditional_gan.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/test_models/test_static_unconditional_gan.py b/tests/test_models/test_static_unconditional_gan.py index 6aca1ca08..4f6b20b45 100644 --- a/tests/test_models/test_static_unconditional_gan.py +++ b/tests/test_models/test_static_unconditional_gan.py @@ -194,14 +194,14 @@ def test_ada_stylegan2_model_cpu(self): type='StaticUnconditionalGAN', generator=dict( type='StyleGANv3Generator', - out_size=32, + out_size=16, style_channels=8, img_channels=3, rgb2bgr=True, synthesis_cfg=synthesis_cfg), discriminator=dict( type='ADAStyleGAN2Discriminator', - in_size=32, + in_size=16, input_bgr2rgb=True, data_aug=dict( type='ADAAug', @@ -220,10 +220,10 @@ def test_ada_stylegan2_model_cpu(self): _ = s3gan(None, 
return_loss=True) # test forward test imgs = s3gan(None, return_loss=False, mode='sampling', num_batches=2) - assert imgs.shape == (2, 3, 32, 32) + assert imgs.shape == (2, 3, 16, 16) # test train step - data = torch.randn((2, 3, 32, 32)) + data = torch.randn((2, 3, 16, 16)) data_input = dict(real_img=data) optimizer_g = torch.optim.SGD(s3gan.generator.parameters(), lr=0.01) optimizer_d = torch.optim.SGD( @@ -232,5 +232,4 @@ def test_ada_stylegan2_model_cpu(self): _ = s3gan.train_step( data_input, optim_dict, running_status=dict(iteration=1)) - _ = s3gan.train_step(data_input, optim_dict) s3gan.discriminator.ada_aug.aug_pipeline.p.dtype == torch.float32 From 74da2240a4486dd6ea185f66d7ff994a16c5d065 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Fri, 1 Apr 2022 17:16:58 +0800 Subject: [PATCH 07/11] fix ci --- tests/test_models/test_static_unconditional_gan.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_models/test_static_unconditional_gan.py b/tests/test_models/test_static_unconditional_gan.py index 4f6b20b45..477396429 100644 --- a/tests/test_models/test_static_unconditional_gan.py +++ b/tests/test_models/test_static_unconditional_gan.py @@ -169,11 +169,13 @@ def test_default_dcgan_model_cuda(self): assert 'loss_disc_fake' in model_outputs['log_vars'] assert 'loss_disc_fake_g' in model_outputs['log_vars'] + @pytest.mark.skipif( + torch.__version__ in ['1.5.1', '1.7.0'], reason='avoid killing') def test_ada_stylegan2_model_cpu(self): synthesis_cfg = { 'type': 'SynthesisNetwork', - 'channel_base': 32768, - 'channel_max': 512, + 'channel_base': 1024, + 'channel_max': 16, 'magnitude_ema_beta': 0.999 } aug_kwargs = { @@ -194,14 +196,14 @@ def test_ada_stylegan2_model_cpu(self): type='StaticUnconditionalGAN', generator=dict( type='StyleGANv3Generator', - out_size=16, + out_size=8, style_channels=8, img_channels=3, rgb2bgr=True, synthesis_cfg=synthesis_cfg), discriminator=dict( type='ADAStyleGAN2Discriminator', - in_size=16, + in_size=8, input_bgr2rgb=True, data_aug=dict( type='ADAAug', @@ -220,10 +222,10 @@ def test_ada_stylegan2_model_cpu(self): _ = s3gan(None, return_loss=True) # test forward test imgs = s3gan(None, return_loss=False, mode='sampling', num_batches=2) - assert imgs.shape == (2, 3, 16, 16) + assert imgs.shape == (2, 3, 8, 8) # test train step - data = torch.randn((2, 3, 16, 16)) + data = torch.randn((2, 3, 8, 8)) data_input = dict(real_img=data) optimizer_g = torch.optim.SGD(s3gan.generator.parameters(), lr=0.01) optimizer_d = torch.optim.SGD( From 18306e0ef01101a3a6bb0f85f54f5740711a864b Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Fri, 1 Apr 2022 18:25:34 +0800 Subject: [PATCH 08/11] fix ci --- tests/test_models/test_static_unconditional_gan.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_models/test_static_unconditional_gan.py b/tests/test_models/test_static_unconditional_gan.py index 477396429..c4c515390 100644 --- a/tests/test_models/test_static_unconditional_gan.py +++ b/tests/test_models/test_static_unconditional_gan.py @@ -169,8 +169,7 @@ def test_default_dcgan_model_cuda(self): assert 'loss_disc_fake' in model_outputs['log_vars'] assert 'loss_disc_fake_g' in model_outputs['log_vars'] - @pytest.mark.skipif( - torch.__version__ in ['1.5.1', '1.7.0'], reason='avoid killing') + @pytest.mark.skipif(torch.__version__ in ['1.5.1'], reason='avoid killing') def test_ada_stylegan2_model_cpu(self): synthesis_cfg = { 'type': 'SynthesisNetwork', From 
1a6e142a2d1fe241d6b6c08a8f0f6fbe4fdacbdd Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Fri, 1 Apr 2022 21:44:53 +0800 Subject: [PATCH 09/11] fix as comment --- .../stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py | 3 +-- .../stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py | 3 +-- .../stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py | 3 +-- mmgen/core/evaluation/metrics.py | 3 ++- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py b/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py index ea763be65..06784cee5 100644 --- a/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py +++ b/configs/styleganv3/stylegan3_r_ada_fp16_gamma3.3_metfaces_1024_b4x8.py @@ -83,7 +83,6 @@ inception_args=dict(type='StyleGAN'), bgr2rgb=True)) -inception_path = None # noqa evaluation = dict( type='GenerativeEvalHook', interval=dict(milestones=[100000], interval=[10000, 5000]), @@ -91,7 +90,7 @@ type='FID', num_images=50000, inception_pkl=inception_pkl, - inception_args=dict(type='StyleGAN', inception_path=inception_path), + inception_args=dict(type='StyleGAN'), bgr2rgb=True), sample_kwargs=dict(sample_model='ema')) diff --git a/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py b/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py index 6185f624a..07f94d4fc 100644 --- a/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py +++ b/configs/styleganv3/stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8.py @@ -80,7 +80,6 @@ inception_args=dict(type='StyleGAN'), bgr2rgb=True)) -inception_path = None # set by user evaluation = dict( type='GenerativeEvalHook', interval=dict(milestones=[80000], interval=[10000, 5000]), @@ -88,7 +87,7 @@ type='FID', num_images=50000, inception_pkl=inception_pkl, - inception_args=dict(type='StyleGAN', inception_path=inception_path), + inception_args=dict(type='StyleGAN'), bgr2rgb=True), sample_kwargs=dict(sample_model='ema')) diff --git a/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py b/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py index 33a5c36f4..100cf3a6b 100644 --- a/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py +++ b/configs/styleganv3/stylegan3_t_noaug_fp16_gamma32.8_ffhq_1024_b4x8.py @@ -56,7 +56,6 @@ inception_args=dict(type='StyleGAN'), bgr2rgb=True)) -inception_path = None # set by user evaluation = dict( type='GenerativeEvalHook', interval=10000, @@ -64,7 +63,7 @@ type='FID', num_images=50000, inception_pkl=inception_pkl, - inception_args=dict(type='StyleGAN', inception_path=inception_path), + inception_args=dict(type='StyleGAN'), bgr2rgb=True), sample_kwargs=dict(sample_model='ema')) diff --git a/mmgen/core/evaluation/metrics.py b/mmgen/core/evaluation/metrics.py index 9637fdf3e..40e06b669 100644 --- a/mmgen/core/evaluation/metrics.py +++ b/mmgen/core/evaluation/metrics.py @@ -1291,7 +1291,8 @@ def get_sampler(self, model, batch_size, sample_model): model (nn.Module): Generative model. batch_size (int): Sampling batch size. sample_model (str): Which model you want to use. ['ema', - 'orig']. Defaults to 'ema'. + 'orig']. Defaults to 'ema'. + Returns: Object: A sampler for calculating path length regularization. 
""" From 6297aae945780f1aaf45c2501d7ef46dab072e19 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Sat, 2 Apr 2022 00:09:19 +0800 Subject: [PATCH 10/11] fix lint --- configs/styleganv3/metafile.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/styleganv3/metafile.yml b/configs/styleganv3/metafile.yml index 2a667d769..36ec52aa9 100755 --- a/configs/styleganv3/metafile.yml +++ b/configs/styleganv3/metafile.yml @@ -11,7 +11,7 @@ Models: In Collection: StyleGANv3 Metadata: Training Data: FFHQ - Name: stylegan3_gamma32.8 + Name: stylegan3_noaug Results: - Dataset: FFHQ Metrics: @@ -24,7 +24,7 @@ Models: In Collection: StyleGANv3 Metadata: Training Data: Others - Name: stylegan3_t_ada_fp16_gamma6.6_metfaces_1024_b4x8 + Name: stylegan3_ada Results: - Dataset: Others Metrics: @@ -37,7 +37,7 @@ Models: In Collection: StyleGANv3 Metadata: Training Data: FFHQ - Name: stylegan3_gamma2.0 + Name: stylegan3_t Results: - Dataset: FFHQ Metrics: From 68ab951d18b3ede1a76e739427cb95d548308ee5 Mon Sep 17 00:00:00 2001 From: plyfager <2744335995@qq.com> Date: Sat, 2 Apr 2022 08:49:32 +0800 Subject: [PATCH 11/11] test s3r --- tests/test_modules/test_stylev3_archs.py | 35 ++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_modules/test_stylev3_archs.py b/tests/test_modules/test_stylev3_archs.py index 12554218e..b883b9331 100644 --- a/tests/test_modules/test_stylev3_archs.py +++ b/tests/test_modules/test_stylev3_archs.py @@ -209,6 +209,20 @@ def setup_class(cls): out_size=16, img_channels=3, synthesis_cfg=synthesis_cfg) + synthesis_r_cfg = { + 'type': 'SynthesisNetwork', + 'channel_base': 1024, + 'channel_max': 16, + 'magnitude_ema_beta': 0.999, + 'conv_kernel': 1, + 'use_radial_filters': True + } + cls.s3_r_cfg = dict( + noise_size=6, + style_channels=8, + out_size=16, + img_channels=3, + synthesis_cfg=synthesis_r_cfg) def test_cpu(self): generator = StyleGANv3Generator(**self.default_cfg) @@ -244,6 +258,18 @@ def test_cpu(self): assert result['noise_batch'].shape == (2, 8) assert result['latent'].shape == (2, 16, 8) + generator = StyleGANv3Generator(**self.s3_r_cfg) + z = torch.randn((2, 6)) + c = None + y = generator(z, c) + assert y.shape == (2, 3, 16, 16) + + y = generator(None, num_batches=2) + assert y.shape == (2, 3, 16, 16) + + res = generator(torch.randn, num_batches=1) + assert res.shape == (1, 3, 16, 16) + @pytest.mark.skipif(not torch.cuda.is_available(), reason='requires cuda') def test_cuda(self): generator = StyleGANv3Generator(**self.default_cfg).cuda() @@ -260,3 +286,12 @@ def test_cuda(self): generator = StyleGANv3Generator(**cfg).cuda() y = generator(None, num_batches=2) assert y.shape == (2, 3, 16, 16) + + generator = StyleGANv3Generator(**self.s3_r_cfg).cuda() + z = torch.randn((2, 6)).cuda() + c = None + y = generator(z, c) + assert y.shape == (2, 3, 16, 16) + + res = generator(torch.randn, num_batches=1) + assert res.shape == (1, 3, 16, 16)