This is an unofficial JAX implementation of *A Sliced Wasserstein Loss for Neural Texture Synthesis* (CVPR 2021).
Please refer to the authors' official repository and cite their paper:
```bibtex
@InProceedings{Heitz_2021_CVPR,
    author    = {Heitz, Eric and Vanhoey, Kenneth and Chambon, Thomas and Belcour, Laurent},
    title     = {A Sliced Wasserstein Loss for Neural Texture Synthesis},
    booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021}
}
```
We require these libraries:

```bash
pip install -U "jax[cuda]" equinox optax tqdm pillow
```
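After installing, the snippet below is one quick way to check which backend JAX picked up. It is a minimal sketch: on a machine without a working CUDA setup, JAX falls back to the CPU backend, which still runs the code but much more slowly.

```python
import jax
import jax.numpy as jnp

# Devices visible to JAX; a working "jax[cuda]" install on a GPU machine
# should list at least one CUDA device here.
print(jax.devices())

# Backend used by default: "gpu" when CUDA is available, otherwise "cpu".
print("default backend:", jax.default_backend())

# A tiny computation to confirm the backend actually executes code.
x = jnp.arange(4.0)
print(float(jnp.dot(x, x)))  # 0 + 1 + 4 + 9 = 14.0
```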
The pre-trained VGG-19 weights `vgg19.npy` are ported from the `vgg19.pth` file provided in the official repository.
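The exact layout of `vgg19.npy` is defined by `texsyn.py`. The sketch below *assumes* (hypothetically) that it stores a pickled dict mapping PyTorch parameter names such as `features.0.weight` to NumPy arrays in PyTorch's OIHW convolution layout. Under that assumption, a ported 3x3 convolution can be applied in JAX with `jax.lax.conv_general_dilated` by declaring PyTorch's dimension order explicitly, so no weight transposition is needed. The helper names `load_vgg_weights` and `conv3x3_relu` are illustrative only.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def load_vgg_weights(path):
    # ASSUMPTION: the file holds a pickled dict {param_name: ndarray}.
    # np.save stores a dict as a 0-d object array, so .item() unwraps it.
    return np.load(path, allow_pickle=True).item()


def conv3x3_relu(x, weight, bias):
    # x: (N, C_in, H, W); weight: (C_out, C_in, 3, 3) in PyTorch OIHW layout.
    # VGG's 3x3 convs use stride 1 and padding 1, i.e. "SAME" padding here.
    y = jax.lax.conv_general_dilated(
        x,
        weight,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )
    return jax.nn.relu(y + bias[None, :, None, None])


# Usage with a fake first layer (64 filters over RGB); the parameter names
# are hypothetical and mirror torchvision's VGG state_dict naming.
fake = {
    "features.0.weight": np.random.randn(64, 3, 3, 3).astype(np.float32),
    "features.0.bias": np.zeros(64, dtype=np.float32),
}
path = os.path.join(tempfile.mkdtemp(), "vgg19_fake.npy")
np.save(path, fake)

params = load_vgg_weights(path)
img = jnp.ones((1, 3, 32, 32), dtype=jnp.float32)
feat = conv3x3_relu(
    img,
    jnp.asarray(params["features.0.weight"]),
    jnp.asarray(params["features.0.bias"]),
)
print(feat.shape)  # (1, 64, 32, 32)
```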
We rewrite the VGG network and the Sliced Wasserstein loss in JAX.
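The Sliced Wasserstein loss treats every pixel's feature vector as a sample, projects the samples onto random unit directions, and compares the *sorted* projections, since sorting solves optimal transport in 1D. Below is a minimal sketch for VGG feature maps. It assumes both maps have the same number of pixels and uses as many random directions as channels (a choice made here for simplicity), so it is not necessarily line-for-line what this repository does.

```python
import jax
import jax.numpy as jnp


def sliced_wasserstein_layer(feat_x, feat_y, key):
    """SW loss between (C, H, W) feature maps with equal H * W."""
    c = feat_x.shape[0]
    # Each pixel becomes one C-dimensional sample: (N, C) with N = H * W.
    x = feat_x.reshape(c, -1).T
    y = feat_y.reshape(c, -1).T
    # C random directions, normalized to unit length (one per column).
    dirs = jax.random.normal(key, (c, c))
    dirs = dirs / jnp.linalg.norm(dirs, axis=0, keepdims=True)
    # Project both sample sets and sort along the pixel axis: in 1D, the
    # optimal transport plan simply matches sorted values.
    px = jnp.sort(x @ dirs, axis=0)
    py = jnp.sort(y @ dirs, axis=0)
    return jnp.mean((px - py) ** 2)


def sliced_wasserstein_loss(feats_x, feats_y, key):
    """Sum of per-layer SW terms, with fresh directions for every layer."""
    keys = jax.random.split(key, len(feats_x))
    return sum(
        sliced_wasserstein_layer(fx, fy, k)
        for fx, fy, k in zip(feats_x, feats_y, keys)
    )


key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (8, 16, 16))
# Same feature vectors at shuffled pixel positions.
perm = jax.random.permutation(jax.random.PRNGKey(1), 16 * 16)
b = a.reshape(8, -1)[:, perm].reshape(8, 16, 16)
print(sliced_wasserstein_layer(a, b, key))        # ~0: positions are ignored
print(sliced_wasserstein_layer(a, a + 1.0, key))  # > 0: distributions differ
```

The shuffled-pixel case shows why this loss suits textures: it compares feature *distributions* and ignores where in the image each feature occurs.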
To synthesize a texture from an exemplar image with the Sliced Wasserstein loss, run:

```bash
python texsyn.py --exemplar_path data/input.png --loss_type sw
```
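Optimization-based synthesis, as in the paper, treats the output pixels as parameters and descends the loss between their features and the exemplar's. The toy sketch below illustrates only that idea. It uses raw RGB pixels as a *hypothetical* stand-in for VGG-19 features and plain gradient descent rather than whatever optimizer `texsyn.py` uses. With these "features" it can match only the exemplar's color distribution, not its structure.

```python
import jax
import jax.numpy as jnp


def sw_loss(fx, fy, key):
    # Sliced Wasserstein term for (C, H, W) maps with equal H * W.
    c = fx.shape[0]
    dirs = jax.random.normal(key, (c, c))
    dirs = dirs / jnp.linalg.norm(dirs, axis=0, keepdims=True)
    px = jnp.sort(fx.reshape(c, -1).T @ dirs, axis=0)
    py = jnp.sort(fy.reshape(c, -1).T @ dirs, axis=0)
    return jnp.mean((px - py) ** 2)


def features(img):
    # HYPOTHETICAL stand-in for VGG-19: the raw (C, H, W) pixels themselves.
    return img


def synthesize(exemplar, key, steps=300):
    init_key, loop_key = jax.random.split(key)
    img = jax.random.uniform(init_key, exemplar.shape)  # start from noise
    target = features(exemplar)
    # The loss averages over all values, so gradients shrink with image size;
    # scaling the step by the number of values keeps updates O(1).
    lr = 0.25 * exemplar.size

    @jax.jit
    def step(img, k):
        loss, grad = jax.value_and_grad(
            lambda im: sw_loss(features(im), target, k)
        )(img)
        return img - lr * grad, loss

    # Fresh random directions at every step.
    for k in jax.random.split(loop_key, steps):
        img, _ = step(img, k)
    return img


# Toy exemplar: a smooth gray ramp, replicated over 3 channels.
ramp = jnp.linspace(0.0, 1.0, 16 * 16).reshape(16, 16)
exemplar = jnp.stack([ramp, ramp, ramp])
out = synthesize(exemplar, jax.random.PRNGKey(0))

noise = jax.random.uniform(jax.random.PRNGKey(1), exemplar.shape)
eval_key = jax.random.PRNGKey(42)
print("noise :", float(sw_loss(noise, exemplar, eval_key)))
print("output:", float(sw_loss(out, exemplar, eval_key)))  # expected lower
```

Swapping `features` for a real VGG-19 forward pass that returns several layers (and summing the per-layer terms) turns this into the full texture-synthesis setup.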
Example results compare the input exemplar with the outputs obtained with the Sliced Wasserstein loss and with the Gram loss.
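The Gram output in that comparison refers to the classic Gram-matrix texture loss (Gatys et al.), which matches only the uncentered second-order channel correlations of each feature map. The small sketch below is not taken from this repository. It shows one thing the Gram loss misses: negating a feature map leaves its Gram matrix unchanged, while a Sliced Wasserstein term still detects the difference.

```python
import jax
import jax.numpy as jnp


def gram_matrix(feat):
    # (C, H, W) -> (C, C) uncentered channel correlations, averaged over pixels.
    c = feat.shape[0]
    f = feat.reshape(c, -1)
    return f @ f.T / f.shape[1]


def gram_loss(feat_x, feat_y):
    return jnp.mean((gram_matrix(feat_x) - gram_matrix(feat_y)) ** 2)


def sw_loss(fx, fy, key):
    # Sliced Wasserstein term: sort random 1D projections and compare them.
    c = fx.shape[0]
    dirs = jax.random.normal(key, (c, c))
    dirs = dirs / jnp.linalg.norm(dirs, axis=0, keepdims=True)
    px = jnp.sort(fx.reshape(c, -1).T @ dirs, axis=0)
    py = jnp.sort(fy.reshape(c, -1).T @ dirs, axis=0)
    return jnp.mean((px - py) ** 2)


key = jax.random.PRNGKey(0)
a = jnp.abs(jax.random.normal(key, (4, 8, 8)))  # non-negative, skewed features
print(float(gram_loss(a, -a)))           # 0.0: Gram matrices are identical
print(float(sw_loss(a, -a, key)) > 0.0)  # True: distributions differ
```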
Thanks to everyone who made the repositories mentioned above public.
Bug reports are welcome; I will fix them as time permits.