38 commits
39753b6
StyleGAN-ada on full layer now working & added support for grayscale …
McFredward Jun 7, 2022
f22e490
Update README.md
McFredward Jun 7, 2022
bfcb9f4
Added stylegan2-ada submodule
McFredward Jun 7, 2022
ac3e6da
Merge branch 'master' of https://github.com/McFredward/ganspace
McFredward Jun 7, 2022
5b86faa
Added stylegan2-ada
McFredward Jun 7, 2022
25208dd
Added custom args & added mapping-layer to decomposition
McFredward Jun 7, 2022
7e54d05
Update README.md
McFredward Jun 7, 2022
0ddc3c4
Update README.md
McFredward Jun 7, 2022
7996f69
Update README.md
McFredward Jun 7, 2022
e034c5a
Update README.md
McFredward Jun 7, 2022
c6dce71
changes back to old implementation
McFredward Jun 7, 2022
710b903
deleted old code-blocks and comments
McFredward Jun 7, 2022
944e30e
Fixed invalid shape bug for grayscale images in google colab
McFredward Jun 8, 2022
7b765ec
Fixed bug if w-space is used
McFredward Jun 8, 2022
0a8b683
fixed --use_w is not working
McFredward Jun 8, 2022
e874e80
removed debug prints
McFredward Jun 8, 2022
c620c61
Update README.md
McFredward Jun 8, 2022
470a3c6
Added scatter block of first two PCs
McFredward Jun 9, 2022
b79d1ce
Added new argument for num of samples in the scatter plot
McFredward Jun 9, 2022
924b98a
fix tensor must be on cpu
McFredward Jun 9, 2022
58dcdf7
fix tensor must be on cpu 2
McFredward Jun 9, 2022
c1b129a
Add more options to controll the scatter plot
McFredward Jun 10, 2022
e2238f0
fixed cpu bug
McFredward Jun 10, 2022
cbc61f9
fixed cpu bug (hopefully)
McFredward Jun 10, 2022
e4c6721
Downscale scatter images to avoid RAM overflow
McFredward Jun 10, 2022
fc3f78b
scatter image plots now based on global mean
McFredward Jun 10, 2022
4a640b5
scattor now also works with --use_w
McFredward Jun 10, 2022
5472bc9
cpu fix
McFredward Jun 10, 2022
625bc86
fixed color images not working
McFredward Jun 10, 2022
17ddf2f
Added mean and sigma ellipse to scatter plot
McFredward Jun 10, 2022
db9d54b
Merge branch 'master' of https://github.com/McFredward/ganspace
McFredward Jun 10, 2022
d83e130
Added scatter preview
McFredward Jun 10, 2022
f7348b4
Update README.md
McFredward Jun 10, 2022
5b7e24c
Update README.md
McFredward Jun 10, 2022
5ecc1fc
Update README.md
McFredward Jun 10, 2022
cd23bc8
Keep axis ratio in scatter plot
McFredward Jun 21, 2022
7232ccc
Merge branch 'master' of https://github.com/McFredward/ganspace
McFredward Jun 21, 2022
20cbc3c
Equal axis for better latent space analysis
McFredward Jun 21, 2022
5 changes: 5 additions & 0 deletions .gitmodules
@@ -6,3 +6,8 @@
path = models/stylegan2/stylegan2-pytorch
url = https://github.com/harskish/stylegan2-pytorch.git
ignore = untracked

[submodule "stylegan2_ada/stylegan2-ada-pytorch"]
path = models/stylegan2_ada/stylegan2-ada-pytorch
url = https://github.com/NVlabs/stylegan2-ada-pytorch.git
ignore = untracked
39 changes: 39 additions & 0 deletions README.md
@@ -1,3 +1,42 @@
# Changes compared to the original repo
* **Added StyleGAN2-ada support** <br />
The following classes for StyleGAN2-ada are available for automatic download:
* `ffhq`
* `afhqcat`
* `afhqdog`
* `afhqwild`
* `brecahad`
* `cifar10`
* `metfaces`

For a custom class, add its name and resolution to the `configs` dictionary in the `StyleGAN2_ada` constructor in `models/wrappers.py`, and place the checkpoint file at `models/checkpoints/stylegan2_ada/stylegan2_{class_name}_{resolution}.pkl` (replace `{class_name}` and `{resolution}` with the values you added to the `configs` dict).

`partial_forward` for StyleGAN2-ada is currently not fully implemented: if you use a layer in the synthesis network as the activation space, it can take longer than with other models, because the complete forward pass is always computed, even if the chosen layer sits early in the network.
* **Added grayscale image support**
* **Added another progress bar during the creation of the images**
* **Added new args for `visualize.py` to control the outcome without changing the code:**

argument | description | arg-datatype
--- | --- | ---
`--plot_directions` | Number of components/directions to plot | int
`--plot_images` | Number of images per component/direction to plot | int
`--video_directions` | Number of components/directions to create a video of | int
`--video_images` | Number of frames within a video of one direction/component | int
* **Added interactive 2D scatter plot of the used activation space:**

<img src="StyleGAN_scatter.png" width=75% height=75%>

argument | description | arg-datatype
--- | --- | ---
`--scatter` | Activate scatter-plot | -
`--scatter_images` | Activate plotting corresponding generated images for each point | -
`--scatter_samples` | Number of samples in the 2D scatter plot | int
`--scatter_x` | Index of the principal component shown on the x-axis of the scatter plot | int
`--scatter_y` | Index of the principal component shown on the y-axis of the scatter plot | int

If `--scatter_images` is active, the interactive plot is saved as `.pickle` which can be opened with `python open_scatter.py [path]`.
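
The checkpoint naming rule for custom classes can be sketched as follows. This is a minimal illustration that only assumes the path template given above; the helper `checkpoint_path` and the class name `mydata` are hypothetical and not part of the repository.

```python
from pathlib import Path

# Hypothetical helper: builds the location where a custom StyleGAN2-ada
# checkpoint is expected, following the template from the README text.
def checkpoint_path(class_name: str, resolution: int,
                    root: str = "models/checkpoints/stylegan2_ada") -> Path:
    return Path(root) / f"stylegan2_{class_name}_{resolution}.pkl"

# 'mydata' and 256 are made-up values standing in for a custom `configs` entry.
print(checkpoint_path("mydata", 256).as_posix())
```

The same name and resolution would have to appear in the `configs` dictionary, otherwise the wrapper has no way to map the class to this file.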


# GANSpace: Discovering Interpretable GAN Controls
![Python 3.7](https://img.shields.io/badge/python-3.7-green.svg)
![PyTorch 1.3](https://img.shields.io/badge/pytorch-1.3-green.svg)
Binary file added StyleGAN_scatter.png
28 changes: 22 additions & 6 deletions config.py
@@ -27,10 +27,10 @@ def __str__(self):
for k, v in self.__dict__.items():
if k == 'default_args':
continue

in_default = k in self.default_args
same_value = self.default_args.get(k) == v

if in_default and same_value:
default[k] = v
else:
@@ -42,15 +42,15 @@ def __str__(self):
}

return json.dumps(config, indent=4)

def __repr__(self):
return self.__str__()

def from_dict(self, dictionary):
for k, v in dictionary.items():
setattr(self, k, v)
return self

def from_args(self, args=sys.argv[1:]):
parser = argparse.ArgumentParser(description='GAN component analysis config')
parser.add_argument('--model', dest='model', type=str, default='StyleGAN', help='The network to analyze') # StyleGAN, DCGAN, ProGAN, BigGAN-XYZ
@@ -67,6 +67,22 @@ def from_args(self, args=sys.argv[1:]):
parser.add_argument('--sigma', type=float, default=2.0, help='Number of stdevs to walk in visualize.py')
parser.add_argument('--inputs', type=str, default=None, help='Path to directory with named components')
parser.add_argument('--seed', type=int, default=None, help='Seed used in decomposition')
parser.add_argument('--plot_directions', dest='np_directions', type=int, default=14, help='Number of components/directions to plot')
parser.add_argument('--plot_images', dest='np_images', type=int, default=5, help='Number of images per component/direction to plot')
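# dest was 'nv_images', which collides with --video_images below; 'nv_directions' is an assumed name chosen by analogy with 'np_directions'
parser.add_argument('--video_directions', dest='nv_directions', type=int, default=5, help='Number of components/directions to create a video of')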
parser.add_argument('--video_images', dest='nv_images', type=int, default=150, help='Number of frames within a video of one direction/component')
parser.add_argument('--scatter', dest='show_scatter', action='store_true', help='Plot a 2D scatter-plot of the activation space of two principal components')
parser.add_argument('--scatter_samples', dest='scatter_samples', type=int, default=1000, help='Number of samples in the 2D scatter plot')
parser.add_argument('--scatter_images', dest='scatter_images', action='store_true', help='Plot encoded images instead of points within the scatter plot')
parser.add_argument('--scatter_x', dest='scatter_x_axis_pc', type=int, default=1, help='Number of PC for x-axis in the scatter plot')
parser.add_argument('--scatter_y', dest='scatter_y_axis_pc', type=int, default=2, help='Number of PC for y-axis in the scatter plot')


args = parser.parse_args(args)
assert args.np_images % 2 != 0, 'The number of plotted images per component (--plot_images) has to be odd.'

if args.model == "StyleGAN2-ada" and args.layer == "g_mapping":
print("No layer 'g_mapping' in StyleGAN2-ada. Assuming you meant 'mapping'.")
args.layer = "mapping"

return self.from_dict(args.__dict__)
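
When two `add_argument` calls share a `dest`, `argparse` stores both flags in the same attribute, so one value silently replaces the other. The sketch below reproduces that with the two video flags from this file; the helper `build_parser` and the separate destination name `nv_directions` are assumptions made for illustration.

```python
import argparse

def build_parser(video_directions_dest: str) -> argparse.ArgumentParser:
    """Build a tiny parser with the two video flags; only the first dest varies."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_directions', dest=video_directions_dest, type=int, default=5)
    parser.add_argument('--video_images', dest='nv_images', type=int, default=150)
    return parser

# Shared dest: both flags write to `nv_images`, so the frame count is lost.
shared = build_parser('nv_images').parse_args(['--video_directions', '3'])

# Separate dest ('nv_directions' is an assumed name): both values survive.
separate = build_parser('nv_directions').parse_args(['--video_directions', '3'])

print(vars(shared))
print(vars(separate))
```

With a shared destination the namespace contains a single `nv_images` entry, and any code that reads it as a frame count would receive the number of directions instead.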
61 changes: 44 additions & 17 deletions decomposition.py
@@ -87,7 +87,7 @@ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
n_samp = max(10_000, config.n) // B * B # make divisible
n_comp = comp.shape[0]
latent_dims = inst.model.get_latent_dims()

# We're looking for M s.t. M*P*G'(Z) = Z => M*A = Z
# Z = batch of latent vectors (n_samples x latent_dims)
# G'(Z) = batch of activations at intermediate layer
@@ -104,7 +104,7 @@ def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
# Dimensions other way around, so these are actually the transposes
A = np.zeros((n_samp, n_comp), dtype=np.float32)
Z = np.zeros((n_samp, latent_dims), dtype=np.float32)

# Project tensor X onto PCs, return coordinates
def project(X, comp):
N = X.shape[0]
@@ -131,7 +131,7 @@ def project(X, comp):
# gelsy = complete orthogonal factorization; sometimes faster
# gelss = SVD; slow but less memory hungry
M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :]

# Solution given by rows of M_t
Z_comp = M_t[:n_comp, :]
Z_mean = np.mean(Z, axis=0, keepdims=True)
@@ -182,6 +182,12 @@ def compute(config, dump_name, instrumented_model):
inst.retain_layer(layer_key)
model.partial_forward(model.sample_latent(1), layer_key)
sample_shape = inst.retained_features()[layer_key].shape

# StyleGAN2-ada's mapping network copies its result 18 times to [B,18,512], so the sample shape differs from the latent shape
# returned by the wrapper, which is only [B,512] so that GANSpace can modify specific style layers individually.
if model.model_name == "StyleGAN2_ada" and model.w_primary:
sample_shape = (sample_shape[0], sample_shape[2])

sample_dims = np.prod(sample_shape)
print('Feature shape:', sample_shape)
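
The effect of the shape correction on `sample_dims` can be checked with plain tuples. This is a small sketch that assumes the retained mapping output has shape `[1, 18, 512]` for a single sampled latent, as the comment in the diff describes; the variable `retained_shape` is introduced here only for illustration.

```python
import math

# Assumed shape of the retained mapping-layer output for one sample: [B, 18, 512].
retained_shape = (1, 18, 512)

# Keep the batch and feature axes, drop the repeated per-layer axis.
sample_shape = (retained_shape[0], retained_shape[2])

# Without the correction, PCA would run on 18 identical copies of each w vector.
dims_before = math.prod(retained_shape)
sample_dims = math.prod(sample_shape)

print('Feature shape:', sample_shape)
print('Dimensions before/after:', dims_before, sample_dims)
```

Dropping the repeated axis keeps the decomposition in the 512-dimensional W space rather than in an 18-fold tiled copy of it.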

@@ -218,7 +224,7 @@ def compute(config, dump_name, instrumented_model):

# Must not depend on chosen batch size (reproducibility)
NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible!

samples = None
if not transformer.batch_support:
samples = np.zeros((N + NB, sample_dims), dtype=np.float32)
@@ -236,7 +242,7 @@ def compute(config, dump_name, instrumented_model):
latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy()

# Decomposition on non-Gaussian latent space
samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W'
samples_are_latents = layer_key in ['g_mapping', 'mapping', 'style'] and inst.model.latent_space_name() == 'W'

canceled = False
try:
@@ -245,15 +251,15 @@ def compute(config, dump_name, instrumented_model):
for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
for mb in range(0, NB, B):
z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device)

if samples_are_latents:
# Decomposition on latents directly (e.g. StyleGAN W)
batch = z.reshape((B, -1))
else:
# Decomposition on intermediate layer
with torch.no_grad():
model.partial_forward(z, layer_key)

# Permuted to place PCA dimensions last
batch = inst.retained_features()[layer_key].reshape((B, -1))

@@ -268,21 +274,21 @@ def compute(config, dump_name, instrumented_model):
except KeyboardInterrupt:
if not transformer.batch_support:
sys.exit(1) # no progress yet

dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}')
print(f'Saving current state to "{dump_name.name}" before exiting')
canceled = True

if not transformer.batch_support:
X = samples # Use all samples
X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...!
X -= X_global_mean

print(f'[{timestamp()}] Fitting whole batch')
t_start_fit = datetime.datetime.now()

transformer.fit(X)

print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero'
else:
@@ -291,7 +297,7 @@ def compute(config, dump_name, instrumented_model):
X -= X_global_mean

X_comp, X_stdev, X_var_ratio = transformer.get_components()

assert X_comp.shape[1] == sample_dims \
and X_comp.shape[0] == config.components \
and X_global_mean.shape[1] == sample_dims \
@@ -349,6 +355,7 @@ def compute(config, dump_name, instrumented_model):
del inst
del model


del X
del X_comp
del random_dirs
@@ -363,20 +370,20 @@ def get_or_compute(config, model=None, submit_config=None, force_recompute=False
if submit_config is None:
wrkdir = str(Path(__file__).parent.resolve())
submit_config = SimpleNamespace(run_dir_root = wrkdir, run_dir = wrkdir)

# Called directly by run.py
return _compute(submit_config, config, model, force_recompute)

def _compute(submit_config, config, model=None, force_recompute=False):
basedir = Path(submit_config.run_dir)
outdir = basedir / 'out'

if config.n is None:
raise RuntimeError('Must specify number of samples with -n=XXX')

if model and not isinstance(model, InstrumentedModel):
raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"')

if config.use_w and not 'StyleGAN' in config.model:
raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}')

@@ -398,5 +405,25 @@ def _compute(submit_config, config, model=None, force_recompute=False):
t_start = datetime.datetime.now()
compute(config, dump_path, model)
print('Total time:', datetime.datetime.now() - t_start)


return dump_path



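def imscatter(x, y, image, ax=None, zoom=1):
    """Draw `image` at every (x, y) data position on `ax` and return the created artists."""
    # Assumes plt, np, OffsetImage and AnnotationBbox are imported at module level (not visible in this hunk).
    if ax is None:
        ax = plt.gca()
    try:
        # Treat `image` as a file path first
        image = plt.imread(image)
    except TypeError:
        # Likely already an array...
        pass
    im = OffsetImage(image, zoom=zoom)
    x, y = np.atleast_1d(x, y)
    artists = []
    for x0, y0 in zip(x, y):
        ab = AnnotationBbox(im, (x0, y0), xycoords='data', frameon=False)
        artists.append(ax.add_artist(ab))
    # Artists added via add_artist do not extend the data limits, so update them explicitly
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists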
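
A possible way to exercise this kind of image scatter on synthetic data is sketched below. It is not the repository's plotting code: `imscatter_arrays` is a hypothetical array-only variant that skips the `plt.imread` step, and the random thumbnail and coordinates stand in for generated images and their principal-component coordinates.

```python
import io

import matplotlib
matplotlib.use("Agg")  # headless backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

def imscatter_arrays(x, y, image, ax, zoom=1.0):
    """Place one copy of `image` (a 2D array) at every (x, y) data position."""
    im = OffsetImage(image, zoom=zoom, cmap="gray")
    x, y = np.atleast_1d(x, y)
    artists = []
    for x0, y0 in zip(x, y):
        box = AnnotationBbox(im, (x0, y0), xycoords="data", frameon=False)
        artists.append(ax.add_artist(box))
    # add_artist does not touch the data limits, so extend them by hand.
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists

rng = np.random.default_rng(0)
coords = rng.normal(size=(5, 2))   # stand-in for two PCA coordinates per sample
thumbnail = rng.random((8, 8))     # stand-in for a downscaled grayscale image

fig, ax = plt.subplots()
artists = imscatter_arrays(coords[:, 0], coords[:, 1], thumbnail, ax, zoom=2.0)
ax.set_aspect("equal", adjustable="datalim")  # equal axes, as in the later commits

buffer = io.BytesIO()
fig.savefig(buffer, format="png")  # render to memory instead of a file
plt.close(fig)
```

Small thumbnails keep memory use low, which matches the motivation given in the commit that downscales the scatter images.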
18 changes: 18 additions & 0 deletions models/stylegan2_ada/__init__.py
@@ -0,0 +1,18 @@
import sys
import os
import shutil
import glob
import platform
from pathlib import Path

current_path = os.getcwd()

module_path = Path(__file__).parent / 'stylegan2-ada-pytorch'
sys.path.append(str(module_path.resolve()))
os.chdir(module_path)

import generate
import legacy
import dnnlib

os.chdir(current_path)
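
The module above changes the working directory for the imports and restores it afterwards. If one of the imports raised, the process would stay in the submodule directory. A context manager is one way to make the restore unconditional; the sketch below uses a hypothetical helper `working_directory` and a temporary directory in place of the submodule path.

```python
import contextlib
import os
import tempfile

@contextlib.contextmanager
def working_directory(path):
    """Temporarily switch the current working directory, restoring it on exit."""
    previous = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Runs even if the body raises, e.g. on a failed import.
        os.chdir(previous)

start = os.getcwd()
with tempfile.TemporaryDirectory() as tmp:
    with working_directory(tmp):
        inside = os.getcwd()  # the imports would happen here
restored = os.getcwd()

print(restored == start)
```

Python 3.11 added `contextlib.chdir` with the same behavior, but a hand-written helper keeps the code usable on the older interpreter versions this repository targets.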
1 change: 1 addition & 0 deletions models/stylegan2_ada/stylegan2-ada-pytorch
Submodule stylegan2-ada-pytorch added at 6f160b