Skip to content

Commit

Permalink
Update model links, add content images
Browse files Browse the repository at this point in the history
  • Loading branch information
gordicaleksa committed May 30, 2020
1 parent fd02641 commit 0d80b83
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 12 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ I've included automatic, pretrained models and MS COCO dataset, download script

## Examples

Here are some examples with the pretrained models:
Here are some examples with the [4 pretrained models](https://drive.google.com/uc?export=download&id=1_Ae_W0q9qN3JtH4tMnSQvuwLiCsWWSTO) (automatic download enabled - look at the [usage section](#usage)):

<p align="center">
<img src="data/style-images/mosaic_crop_resized_230.jpg" width="230px">
Expand All @@ -39,6 +39,8 @@ Here are some examples with the pretrained models:
<img src="data/examples/candy_model/figures_width_500_model_candy_resize_230.jpg" width="342px">
</p>

*Note:* keep in mind that I still need to improve these models; 3 of them only saw 33k images from MS COCO.

## Setup

1. Run `conda env create` from project directory.
Expand Down
Binary file added data/content-images/lion.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/content-images/taj_mahal.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions stylization_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def stylize_static_image(inference_config):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

content_img_path = os.path.join(inference_config['content_images_path'], inference_config['content_img_name'])
content_image = utils.prepare_img(content_img_path, inference_config['img_height'], device)
content_image = utils.prepare_img(content_img_path, inference_config['img_width'], device)

# load the weights and set the model to evaluation mode
stylization_model = TransformerNet().to(device)
Expand Down Expand Up @@ -43,9 +43,9 @@ def stylize_static_image(inference_config):
# Modifiable args - feel free to play with these
#
parser = argparse.ArgumentParser()
parser.add_argument("--content_img_name", type=str, help="content image to stylize", default='figures.jpg')
parser.add_argument("--img_height", type=int, help="resize content image to this height", default=None)
parser.add_argument("--model_name", type=str, help="model binary to use for stylization", default='starry_v2.pth')
parser.add_argument("--content_img_name", type=str, help="content image to stylize", default='taj_mahal.jpg')
parser.add_argument("--img_width", type=int, help="resize content image to this width", default=500)
parser.add_argument("--model_name", type=str, help="model binary to use for stylization", default='starry_v3.pth')
args = parser.parse_args()

# Wrapping inference configuration into a dictionary
Expand Down
2 changes: 1 addition & 1 deletion training_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def train(training_config):
#
parser = argparse.ArgumentParser()
# training related
parser.add_argument("--style_img_name", type=str, help="style image name that will be used for training", default='mosaic.jpg')
parser.add_argument("--style_img_name", type=str, help="style image name that will be used for training", default='edtaonisl.jpg')
parser.add_argument("--content_weight", type=float, help="weight factor for content loss", default=1e0) # you don't need to change this one just play with style loss
parser.add_argument("--style_weight", type=float, help="weight factor for style loss", default=4e5)
parser.add_argument("--tv_weight", type=float, help="weight factor for total variation loss", default=0)
Expand Down
4 changes: 2 additions & 2 deletions utils/resource_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
# If the link is broken you can download the MS COCO 2014 dataset manually from http://cocodataset.org/#download
MS_COCO_2014_TRAIN_DATASET_PATH = r'http://images.cocodataset.org/zips/train2014.zip' # ~13 GB after unzipping

# todo: update the link after I train new models using this repo
PRETRAINED_MODELS_PATH = r'https://drive.google.com/uc?export=download&id=1_Ae_W0q9qN3JtH4tMnSQvuwLiCsWWSTO'
PRETRAINED_MODELS_PATH = r'https://www.dropbox.com/s/fb39gscd1b42px1/pretrained_models.zip?dl=1' # r'https://drive.google.com/uc?export=download&id=1_Ae_W0q9qN3JtH4tMnSQvuwLiCsWWSTO'

DOWNLOAD_DICT = {
'pretrained_models': PRETRAINED_MODELS_PATH,
Expand All @@ -28,6 +27,7 @@

# step1: download the resource to local filesystem
remote_resource_path = DOWNLOAD_DICT[args.resource]
print(f'Downloading from {remote_resource_path}')
resource_tmp_path = args.resource + '.zip'
download_url_to_file(remote_resource_path, resource_tmp_path)

Expand Down
10 changes: 6 additions & 4 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def load_image(img_path, target_shape=None):
img = cv.imread(img_path)[:, :, ::-1] # [:, :, ::-1] converts BGR (opencv format...) into RGB

if target_shape is not None: # resize section
if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the height
if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the width
current_height, current_width = img.shape[:2]
new_height = target_shape
new_width = int(current_width * (new_height / current_height))
new_width = target_shape
new_height = int(current_height * (new_width / current_width))
img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
else: # set both dimensions to target shape
img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
Expand Down Expand Up @@ -67,7 +67,9 @@ def save_and_maybe_display_image(inference_config, dump_img, should_display=Fals
assert isinstance(dump_img, np.ndarray), f'Expected numpy array got {type(dump_img)}.'

dump_img = post_process_image(dump_img)
dump_img_name = inference_config['content_img_name'].split('.')[0] + '_' + str(inference_config['img_height']) + '_' + inference_config['model_name'] + '.jpg'
if inference_config['img_height'] is None:
inference_config['img_height'] = dump_img.shape[0]
dump_img_name = inference_config['content_img_name'].split('.')[0] + '_width_' + str(inference_config['img_height']) + '_model_' + inference_config['model_name'].split('.')[0] + '.jpg'
cv.imwrite(os.path.join(inference_config['output_images_path'], dump_img_name), dump_img[:, :, ::-1]) # ::-1 because opencv expects BGR (and not RGB) format...
print(f'Saved image to {inference_config["output_images_path"]}.')

Expand Down

0 comments on commit 0d80b83

Please sign in to comment.