This repo contains the Keras code for sketch-rnn. For deeper insight, you can refer to the excellent blog posts and the paper by David Ha. A Jupyter notebook example is also provided as a demonstration.
You can also find the original GitHub repo in the Magenta project.
First, install git if you do not have it yet.
Next, clone this repository by opening a terminal and typing the following commands:
$ cd $HOME # or any other development directory you prefer
$ git clone https://github.com/KKeishiro/Sketch-RNN.git
$ cd Sketch-RNN
If you do not want to install git, you can instead download this repo.
- TensorFlow 1.11.0
- Keras 2.2.2
- Python 3.6
Although the data folder contains several datasets, pre-trained model weights are provided only for the owl dataset.
Some notes: the RNN cell type is limited to LSTM, whereas the original implementation also supports LSTM with Layer Normalization and HyperLSTM. Annealing of the KL loss term is not implemented either.
python train.py --data_dir=dataset_path --log_root=checkpoint_path [--resume_training --weights=weights_path]
For example,
python train.py --log_root=models/elephant
Some of the sketches are quite recognizable as owls, while others are far from it. For instance, the second one from the left has four eyes, and the second one from the right consists only of overlapping circles.
I assume this behavior is caused by insufficient training time. However, the limited choice of RNN cell and the missing KL loss annealing may also have a considerable impact on the sampling results.
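For readers who want to experiment with KL loss annealing, here is a minimal sketch of the schedule described in the sketch-rnn paper, where the KL weight follows `eta_step = 1 - (1 - eta_min) * R**step`. It is not part of this repo: the function names `kl_weight` and `annealed_loss` are hypothetical, and the default values are assumptions chosen for illustration. In a Keras model the weight would probably need to be held in a backend variable and updated from a callback, which is not shown here.

```python
def kl_weight(step, eta_min=0.01, decay_rate=0.99995):
    """Annealing factor eta_step = 1 - (1 - eta_min) * R**step.

    Starts at eta_min for step 0 and approaches 1 as step grows.
    The default values are illustrative assumptions.
    """
    return 1.0 - (1.0 - eta_min) * decay_rate ** step


def annealed_loss(reconstruction_loss, kl_loss, step, w_kl=0.5, kl_min=0.2):
    """Total loss L_R + w_KL * eta_step * max(L_KL, KL_min).

    w_kl and kl_min are hypothetical defaults for this sketch.
    """
    eta = kl_weight(step)
    return reconstruction_loss + w_kl * eta * max(kl_loss, kl_min)


if __name__ == "__main__":
    # Print how the KL weight grows over training steps.
    for step in (0, 10000, 50000, 100000):
        print(step, round(kl_weight(step), 4))
```

Early in training the KL term contributes very little, so the model can focus on reconstruction first; the penalty is then phased in gradually.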