Skip to content

Commit

Permalink
New features & bug fixes (#5)
Browse files Browse the repository at this point in the history
**Bug Fixes:**
* Resolved confusion between BGR and RGB usage. Some custom models may need to be retrained.
* Visualization scripts can now load all 3 starter models without errors.
* The `-seed` parameter in the `train_googlenet.py` script should work more effectively now.

**Changes:**

* The `train_googlenet.py` script now saves mean and standard deviation values in BGR format rather than RGB format.
* The visualization scripts now expect mean and standard deviation values in BGR format rather than RGB format.
* The `calc_ms.py` script now outputs normalization values in BGR format rather than RGB.
* The `vis_fc.py` script has been replaced with `vis_multi.py`. 

**New Features:**
* The new `vis_multi.py` script lets you visualize all channels in any specified layer, and lets you select a visualization batch size among other new features.
* Added new tool to edit model file values, called `edit_model.py`.
* Added new normalization value format attribute to models for easier handling of BGR and RGB models.
* Added  `-use_rgb` parameter to the `calc_ms.py` script for if you want the original behavior. 

---


To update your old models to the new correct format:


```
python data_tools/edit_model.py -model_file <your-model.pth> -normval_format bgr -reverse_normvals -output_name <updated-model.pth>
```
  • Loading branch information
ProGamerGov authored Sep 15, 2020
1 parent fdba70d commit b912f21
Show file tree
Hide file tree
Showing 14 changed files with 396 additions and 156 deletions.
4 changes: 2 additions & 2 deletions INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ val Loss: 0.6986 Acc: 0.5125
After training has finished, we can now visualize the newly created DeepDream model's FC layers using the following command:

```
python vis_fc.py -model_file bvlc_out010.pth -model_epoch 10 -num_iterations 200
python vis_multi.py -model_file bvlc_out010.pth -num_iterations 200
```

The `vis_fc.py` script should end up creating two output images, where one image has more circlelike features and the other has more squarelike features. Using more complex datasets that have more classes and images will yield far better looking results. You can find a list of image collection tools, possible sources of images, and duplicate image detection tools on the [dream-creator wiki](https://github.com/ProGamerGov/dream-creator/wiki).
The `vis_multi.py` script should end up creating two output images, where one image has more circlelike features and the other has more squarelike features. Using more complex datasets that have more classes and images will yield far better looking results. You can find a list of image collection tools, possible sources of images, and duplicate image detection tools on the [dream-creator wiki](https://github.com/ProGamerGov/dream-creator/wiki).

Finally, to visualize a single layer and channel or to DeepDream your own image, we can use the following command:

Expand Down
17 changes: 13 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,21 @@ python train_googlenet.py -data_path <training_data> -balance_classes -batch_siz

## Visualizing GoogleNet FC Layer Results

After training a new DeepDream model, you'll need to test it's visualizations. The best visualizations are found in the main FC layer also known as the 'logits' layer. This script helps you quickly and easily visualize all FC layer channels for a particular model epoch, by generating a separate image for each channel.
After training a new DeepDream model, you'll need to test it's visualizations. The best visualizations are found in the main FC layer also known as the 'logits' layer. This script helps you quickly and easily visualize all of a specified layer's channels in a particular model for a particular model epoch, by generating a separate image for each channel.

**Input options:**
* `-model_file`: Path to the pretrained GoogleNet model that you wish to use.
* `-learning_rate`: Learning rate to use with the ADAM or L-BFGS optimizer. Default is `1.5`.
* `-optimizer`: The optimization algorithm to use; either `lbfgs` or `adam`; default is `adam`.
* `-num_iterations`: Default is 500.
* `-num_iterations`: Default is `500`.
* `-layer`: The specific layer you wish to use. Default is set to `fc`.
* `-image_size`: A comma separated list of `<height>,<width>` to use for the output image. Default is set to `224,224`.
* `-jitter`: The amount of image jitter to use for preprocessing. Default is `32`.

**Processing options:**
* `-batch_size`: How many channel visualization images to create in each batch. Default is `10`.
* `-start_channel`: What channel to start creating visualization images at. Default is `0`.
* `-end_channel`: What channel to stop creating visualization images at. Default is set to `-1` for all channels.

**Only Required If Model Doesn't Contain Them, Options**:
* `-model_epoch`: The training epoch that the model was saved from, to use for the output image names. Default is `120`.
Expand All @@ -194,7 +202,7 @@ After training a new DeepDream model, you'll need to test it's visualizations. T
Basic FC (logits) layer visualization:

```
python vis_fc.py -model_file <bvlc_out120>.pth
python vis_multi.py -model_file <bvlc_out120>.pth
```

---
Expand All @@ -207,11 +215,12 @@ This script lets you create DeepDream hallucinations with trained GoogleNet mode
* `-model_file`: Path to the pretrained GoogleNet model that you wish to use.
* `-learning_rate`: Learning rate to use with the ADAM or L-BFGS optimizer. Default is `1.5`.
* `-optimizer`: The optimization algorithm to use; either `lbfgs` or `adam`; default is `adam`.
* `-num_iterations`: Default is 500.
* `-num_iterations`: Default is `500`.
* `-content_image`: Path to your input image. If no input image is specified, random noise is used instead.
* `-layer`: The specific layer you wish to use. Default is set to `mixed5a`.
* `-channel`: The specific layer channel you wish to use. Default is set to `-1` to disable specific channel selection.
* `-image_size`: A comma separated list of `<height>,<width>` to use for the output image. Default is set to `224,224`.
* `-jitter`: The amount of image jitter to use for preprocessing. Default is `32`.

**Only Required If Model Doesn't Contain Them, Options**:
* `-data_mean`: Your precalculated list of mean values that was used to train the model, if they weren't saved inside the model.
Expand Down
28 changes: 26 additions & 2 deletions data_tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ All of these scripts with the exception of `sort_images.py` can be copied to and

1. [Reduce Model Size](https://github.com/ProGamerGov/dream-creator/tree/master/data_tools#reduce-model-size)

2. [Add/Change Model Values](https://github.com/ProGamerGov/dream-creator/tree/master/data_tools#addchange-model-values)

3. [Visualization & Training Tools]()

1. [Graph Training Data](https://github.com/ProGamerGov/dream-creator/tree/master/data_tools#graph-training-data)
Expand All @@ -43,6 +45,7 @@ python calc_ms.py -data_path <training_data>

* `-not_caffe`: Enabling this flag will result in the mean and standard deviation output having a range of 0-1 instead of 0-255.

* `-use_rgb`: Enabling this flag will result in output values being in RGB format instead of BGR.

## FC Channel Contents

Expand Down Expand Up @@ -74,7 +77,7 @@ This script will try to automatically detect corrupt images that interfere with
python find_bad.py -data_path <training_data>
```

* `-delete_bad`: Enabling this flag will result in corrupt images being deleted automatically from the specified dataset.
* `-delete_bad`: Enabling this flag will result in corrupt images being deleted automatically from the specified dataset.


## Image Extractor
Expand Down Expand Up @@ -124,6 +127,27 @@ python strip_model.py -model_file <bvlc_out120>.pth -output_name stripped_models

* `-delete_branches`: If this flag is enabled, any auxiliary branches in the model will be removed.


## Add/Change Model Values

If need to add or change any of the stored model values then use this script. Any options left as `ignore` or `-1` will not be added/changed. This script can be useful for fixing bugs, adding new models, and adding missing values.

```
python edit_model.py -model_file <bvlc_out120>.pth -base_model bvlc -num_classes 10 -output_name edited_model.pth
```

* `-model_file`: Path to your pretrained GoogleNet model file.
* `-output_name`: Name of the output model. If left blank, no output model will be saved.
* `-data_mean`: Your precalculated list of mean values that was used to train the model. Default is `ignore`.
* `-data_sd`: Your precalculated list of standard deviation values that was used to train the model. Default is `ignore`.
* `-normval_format`: The format of your mean and standard deviation values; one of `bgr`, `rgb`, `ignore`. Default is `ignore`.
* `-has_branches`: Whether or not the model has branches; one of `true`, `false`, `ignore`. Default is `ignore`.
* `-base_model`: The base model used to create your model; one of `bvlc`, `p365`, `5h`, `ignore`. Default is `ignore`.
* `-num_classes`: Set the number of model classes. Default is set to `-1` to ignore.
* `-model_epoch`: Set the model epoch. Default is set to `-1` to ignore.
* `-reverse_normvals`: If this flag is enabled, mean and standard deviation values added to the model and stored in the model will be reversed. In essence BGR values are converted to RGB and vice versa.
* `-print_vals`: If this flag is enabled, all stored model values from the loaded model will be printed.

---

# Comparison of Results
Expand All @@ -145,7 +169,7 @@ python resize_data.py -csv_file train_acc.txt

## Image Grid Creator

This script will put images created by `vis_fc.py` into a grid for easy comparisons and analysis
This script will put images created by `vis_multi.py` into a grid for easy comparisons and analysis

```
python make_grid.py -input_path <image_dir>
Expand Down
6 changes: 5 additions & 1 deletion data_tools/calc_ms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ def main():
parser.add_argument("-data_path", help="Path to your dataset", type=str, default='')
parser.add_argument("-batch_size", type=int, default=10)
parser.add_argument("-not_caffe", action='store_true')
parser.add_argument("-use_rgb", action='store_true')
params = parser.parse_args()
main_calc(params)

Expand All @@ -21,7 +22,10 @@ def main_calc(params):

if not params.not_caffe:
range_change = transforms.Compose([transforms.Lambda(lambda x: x*255)])
transform_list = transform_list + [range_change]
transform_list += [range_change]
if not params.use_rgb:
rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
transform_list += [rgb2bgr]

dataset = torchvision.datasets.ImageFolder(params.data_path, transform=transforms.Compose(transform_list))
loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, num_workers=0, shuffle=False)
Expand Down
118 changes: 118 additions & 0 deletions data_tools/edit_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import argparse
import torch
import copy


def main():
parser = argparse.ArgumentParser()
parser.add_argument("-model_file", type=str, default='')
parser.add_argument("-num_classes", type=int, default=-1)
parser.add_argument("-epoch", type=int, default=-1)
parser.add_argument("-base_model", choices=['bvlc', 'p365', '5h', 'ignore'], default='ignore')
parser.add_argument("-data_mean", type=str, default='ignore')
parser.add_argument("-data_sd", type=str, default='ignore')
parser.add_argument("-normval_format", choices=['bgr', 'rgb', 'ignore'], default='ignore')
parser.add_argument("-has_branches", choices=['true', 'false', 'ignore'], default='ignore')
parser.add_argument("-reverse_normvals", action='store_true')
parser.add_argument("-print_vals", action='store_true')
parser.add_argument("-output_name", type=str, default='')
params = parser.parse_args()
main_func(params)


def main_func(params):
checkpoint = torch.load(params.model_file, map_location='cpu')
save_model = copy.deepcopy(checkpoint)

if params.print_vals:
print_model_vals(save_model)

if params.num_classes > -1:
save_model['num_classes'] = params.num_classes

if params.base_model != 'ignore':
save_model['base_model'] = params.base_model

if params.has_branches != 'ignore':
has_branches = True if params.has_branches == 'true' else False
save_model['has_branches'] = has_branches

if params.epoch != -1:
save_model['epoch'] = params.epoch

if params.data_mean != 'ignore' or params.data_sd != 'ignore' or params.normval_format != 'ignore':
try:
norm_vals = save_model['normalize_params']
if params.data_mean != 'ignore':
norm_vals[0] = [float(m) for m in params.data_mean.split(',')]
if params.data_sd != 'ignore':
norm_vals[1] = [float(s) for s in params.data_sd.split(',')]
if params.normval_format != 'ignore':
try:
norm_vals[2] = params.normval_format
except:
norm_vals += [params.normval_format] # Add to legacy models
save_model['normalize_params'] = norm_vals

except:
assert params.data_mean != 'ignore', "'-data_mean' is required"
assert params.data_sd != 'ignore', "'-data_sd' is required"
assert params.normval_format != 'ignore', "'-normval_format' is required"
save_model['normalize_params'] = [params.data_mean, params.data_sd, params.normval_format]

if params.reverse_normvals:
norm_vals = save_model['normalize_params']
norm_vals[0].reverse()
norm_vals[1].reverse()
save_model['normalize_params'] = norm_vals

if params.output_name != '':
torch.save(save_model, save_name)


def print_model_vals(model):
print('Model Values')

try:
print(' Num classes:', model['num_classes'])
except:
pass
try:
print(' Base model:', model['base_model'])
except:
pass
try:
print(' Model epoch:', model['epoch'])
except:
pass
try:
print(' Has branches:', model['has_branches'])
except:
pass
try:
print(' Norm value format', model['normalize_params'][2])
except:
pass
try:
print(' Mean values:', model['normalize_params'][0])
except:
pass
try:
print(' Standard deviation values:', model['normalize_params'][1])
except:
pass
try:
test = model['optimizer_state_dict']
print(' Contains saved optimizer state')
except:
pass
try:
test = model['lrscheduler_state_dict']
print(' Contains saved learning rate scheduler state')
except:
pass



if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion data_tools/make_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def main_test(params):
if params.epoch != -1:
image_list = filter_images(image_list, 'e' + str(params.epoch).zfill(3))
if params.channel != -1:
image_list = filter_images(image_list, 'c' + str(params.channel).zfill(2))
image_list = filter_images(image_list, 'c' + str(params.channel).zfill(4))

if not params.disable_natsort:
image_list.sort(key=n_keys)
Expand Down
2 changes: 0 additions & 2 deletions data_tools/sort_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def main_func(params):
params.data_sd = norm_vals[1]
else:
params.data_mean = [float(m) for m in params.data_mean.split(',')]
params.data_mean.reverse() # RGB to BGR
params.data_sd = [float(m) for m in params.data_sd.split(',')]
params.data_sd.reverse() # RGB to BGR

cnn = cnn.to(params.use_device).eval()
for param in cnn.parameters():
Expand Down
15 changes: 6 additions & 9 deletions train_googlenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,13 @@ def main_func(params):

if params.seed > -1:
set_seed(params.seed)
rnd_generator = torch.Generator(device='cpu') if params.seed > -1 else None

# Setup image training data
if params.balance_classes:
training_data, num_classes, class_weights = load_dataset(data_path=params.data_path, val_percent=params.val_percent, batch_size=params.batch_size, \
input_mean=params.data_mean, input_sd=params.data_sd, use_caffe=not params.not_caffe, \
train_workers=params.train_workers, val_workers=params.val_workers, balance_weights=params.balance_classes)
else:
training_data, num_classes = load_dataset(data_path=params.data_path, val_percent=params.val_percent, batch_size=params.batch_size, \
input_mean=params.data_mean, input_sd=params.data_sd, use_caffe=not params.not_caffe, \
train_workers=params.train_workers, val_workers=params.val_workers, balance_weights=False)
training_data, num_classes, class_weights = load_dataset(data_path=params.data_path, val_percent=params.val_percent, batch_size=params.batch_size, \
input_mean=params.data_mean, input_sd=params.data_sd, use_caffe=not params.not_caffe, \
train_workers=params.train_workers, val_workers=params.val_workers, balance_weights=params.balance_classes, \
rnd_generator=rnd_generator)


# Setup model definition
Expand Down Expand Up @@ -147,7 +144,7 @@ def main_func(params):
torch.backends.cudnn.enabled = True


save_info = [[params.data_mean, params.data_sd], num_classes, has_branches, base_model]
save_info = [[params.data_mean, params.data_sd, 'BGR'], num_classes, has_branches, base_model]

# Train model
train_model(model=cnn, dataloaders=training_data, criterion=criterion, optimizer=optimizer, lrscheduler=lrscheduler, \
Expand Down
4 changes: 3 additions & 1 deletion utils/inceptionv1_caffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def __init__(self, out_features=1000, load_branches=True, mode='bvlc'):
super(InceptionV1_Caffe, self).__init__()
self.mode = mode
self.use_branches = load_branches
self.use_fc = True

if self.mode == 'p365' or self.mode == 'bvlc':
lrn_vals = (5, 9.999999747378752e-05, 0.75, 1)
Expand Down Expand Up @@ -183,7 +184,8 @@ def forward(self, x):
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.drop(x)
x = self.fc(x)
if self.use_fc:
x = self.fc(x)

if not self.use_branches:
return x
Expand Down
16 changes: 11 additions & 5 deletions utils/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def split_percent(n, pctg=0.2):
# Create training and validation images from single set of images
def load_dataset(data_path='test_data', val_percent=0.2, batch_size=1, input_size=(224,224), \
input_mean=[0.485, 0.456, 0.406], input_sd=[0.229, 0.224, 0.225], use_caffe=False, \
train_workers=25, val_workers=5, balance_weights=False):
train_workers=25, val_workers=5, balance_weights=False, rnd_generator=None):

num_classes = sum(os.path.isdir(os.path.join(data_path, i)) for i in os.listdir(data_path))

Expand Down Expand Up @@ -52,9 +52,13 @@ def load_dataset(data_path='test_data', val_percent=0.2, batch_size=1, input_siz
get_fc_channel_classes(data_path)
if val_percent > 0:
lengths = split_percent(len(full_dataset), val_percent)
t_data, v_data = torch.utils.data.random_split(full_dataset, lengths)
if rnd_generator == None:
t_data, v_data = torch.utils.data.random_split(full_dataset, lengths)
else:
t_data, v_data = torch.utils.data.random_split(full_dataset, lengths, generator=rnd_generator)
else:
t_data, v_data = copy.deepcopy(full_dataset), copy.deepcopy(full_dataset)
t_data, v_data = torch.utils.data.Subset(copy.deepcopy(full_dataset), range(0, len(full_dataset))), \
torch.utils.data.Subset(copy.deepcopy(full_dataset), range(0, len(full_dataset)))

# Use separate transforms for training and validation data
t_data = copy.deepcopy(t_data)
Expand All @@ -66,21 +70,23 @@ def load_dataset(data_path='test_data', val_percent=0.2, batch_size=1, input_siz
batch_size=batch_size,
num_workers=train_workers,
shuffle=True,
generator=rnd_generator,
)
val_loader = torch.utils.data.DataLoader(
v_data,
batch_size=batch_size,
num_workers=val_workers,
shuffle=True,
generator=rnd_generator,
)

if balance_weights:
train_class_counts = count_classes(train_loader.dataset)
train_weights = [1 / train_class_counts[class_id] for class_id in range(num_classes)]
train_weights = torch.FloatTensor(train_weights)
return {'train': train_loader, 'val': val_loader}, num_classes, train_weights
else:
return {'train': train_loader, 'val': val_loader}, num_classes
train_weights = None
return {'train': train_loader, 'val': val_loader}, num_classes, train_weights


# Get the number of images in each class in a dataset
Expand Down
7 changes: 7 additions & 0 deletions utils/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ def load_model(model_file, num_classes=120, has_branches=True, mode='bvlc'):
model_keys = checkpoint.keys()
cnn.load_state_dict(checkpoint['model_state_dict'])
else:
base_name = os.path.basename(model_file)
if base_name.lower() == 'pt_bvlc.pth' or base_name.lower() == 'pt_places365.pth':
cnn.use_fc = False
if base_name.lower() == 'pt_bvlc.pth' or base_name.lower() == 'pt_inception5h.pth':
norm_vals = [[103.939,116.779,123.68], [1,1,1], 'BGR']
elif base_name.lower() == 'pt_places365.pth':
norm_vals = [[104.051,112.514,116.676], [1,1,1], 'BGR']
cnn.load_state_dict(checkpoint)
return cnn, norm_vals, num_classes

Expand Down
Loading

0 comments on commit b912f21

Please sign in to comment.