forked from Pay20Y/GCAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 3353c52
Showing
22 changed files
with
3,791 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Gaussian Constrained Attention Network for Scene Text Recognition | ||
## Introduction | ||
Implementation of the paper "Gaussian Constrained Attention Network for Scene Text Recognition" (Under Review) | ||
## How to use | ||
### Install | ||
``` | ||
pip3 install -r requirements.txt | ||
``` | ||
### Train | ||
* <b> Data prepare</b> | ||
LMDB format is suggested. refer [here](https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py) to generate data in LMDB format. | ||
|
||
* <b> Run</b> | ||
|
||
``` | ||
python3 train.py --checkpoints /path/to/save/checkpoints --train_data_dir /path/to/your/train/LMDB/data/dir --test_data_dir /path/to/your/validation/LMDB/data/dir -g "0" --train_batch_size 128 --val_batch_size 128 --aug True --att_loss_type "l1" --att_loss_weight 10.0 | ||
``` | ||
More hyper-parameters please refer to [config.py](https://github.com/Pay20Y/GCAN/blob/master/config.py) | ||
|
||
### Test | ||
|
||
* Download the pretrained model from [BaiduYun](https://pan.baidu.com/s/1hY374pvtDtgeBUPsG7R5ew) (key:w14k) | ||
* Download the benchmark datasets from [BaiduYun](https://pan.baidu.com/s/1Z4aI1_B7Qwg9kVECK0ucrQ) (key: nphk) shared by clovaai in this [repo](https://github.com/clovaai/deep-text-recognition-benchmark) | ||
|
||
``` | ||
python3 test.py --checkpoints /path/to/the/pretrained/model --test_data_dir /path/to/the/evaluation/benchmark/lmdb/dir -g "0" | ||
``` | ||
|
||
## Experiments on benchmarks | ||
|
||
| IIIT5K | IC13 | IC15 | SVT | SVTP | CUTE | | ||
|:-------:|:-----:|:------:|:----:|:-----:|:------:| | ||
| 94.4 | 93.3 | 77.1 | 90.1 | 81.2 | 85.6 | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description="Softmax loss classification") | ||
|
||
# Data | ||
parser.add_argument('--train_data_dir', nargs='+', type=str, metavar='PATH', | ||
default=[None]) | ||
parser.add_argument('--train_data_gt', nargs='+', type=str, metavar='PATH', | ||
default=[None]) | ||
parser.add_argument('--test_data_dir', type=str, metavar='PATH', | ||
default=None) | ||
parser.add_argument('--test_data_gt', type=str, metavar='PATH', | ||
default=None) | ||
parser.add_argument('-b', '--train_batch_size', type=int, default=128) | ||
parser.add_argument('-v', '--val_batch_size', type=int, default=128) | ||
parser.add_argument('-j', '--workers', type=int, default=2) | ||
parser.add_argument('-g', '--gpus', type=str, default='1') | ||
parser.add_argument('--height', type=int, default=48, help="input height") | ||
parser.add_argument('--width', type=int, default=160, help="input width") | ||
parser.add_argument('--aug', type=bool, default=False, help="using data augmentation or not") | ||
parser.add_argument('--keep_ratio', action='store_true', default=True, | ||
help='length fixed or lenghth variable.') | ||
parser.add_argument('--voc_type', type=str, default='ALLCASES_SYMBOLS', | ||
choices=['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']) | ||
parser.add_argument('--num_train', type=int, default=-1) | ||
parser.add_argument('--num_test', type=int, default=-1) | ||
|
||
# Model | ||
parser.add_argument('--max_len', type=int, default=30) | ||
parser.add_argument('--encoder_sdim', type=int, default=512, | ||
help="the dim of hidden layer in encoder.") | ||
parser.add_argument('--encoder_layers', type=int, default=2, | ||
help="the num of layers in encoder lstm.") | ||
parser.add_argument('--decoder_sdim', type=int, default=512, | ||
help="the dim of hidden layer in decoder.") | ||
parser.add_argument('--decoder_layers', type=int, default=2, | ||
help="the num of layers in decoder lstm.") | ||
parser.add_argument('--decoder_edim', type=int, default=512, | ||
help="the dim of embedding layer in decoder.") | ||
parser.add_argument("--att_loss_type", type=str, default='l1') | ||
parser.add_argument("--att_loss_weight", type=float, default=10.) | ||
|
||
# Optimizer | ||
parser.add_argument('--lr', type=float, default=0.0008, help="learning rate of new parameters, for pretrained ") | ||
parser.add_argument('--weight_decay', type=float, default=0.9) # the model maybe under-fitting, 0.0 gives much better results. | ||
parser.add_argument('--decay_iter', type=int, default=100000) | ||
parser.add_argument('--decay_bound', nargs='+', type=int, default=[180000, 240000]) # [120000, 150000] (800K) | ||
# parser.add_argument('--lr_stage', nargs='+', type=float, default=[0.001, 0.1, 0.01]) | ||
parser.add_argument('--lr_stage', nargs='+', type=float, default=[0.0001, 0.00001, 0.000001]) | ||
parser.add_argument('--decay_end', type=float, default=0.00001) | ||
parser.add_argument('--grad_clip', type=float, default=-1.0) | ||
parser.add_argument('--iters', type=int, default=3000000) | ||
parser.add_argument('--decode_type', type=str, default='greed') | ||
|
||
parser.add_argument('--resume', type=bool, default=False) | ||
parser.add_argument('--pretrained', type=str, default='', metavar='PATH') | ||
parser.add_argument('--checkpoints', type=str, default='./checkpoints/', metavar='PATH') | ||
parser.add_argument('--log_iter', type=int, default=100) | ||
parser.add_argument('--summary_iter', type=int, default=1000) | ||
parser.add_argument('--eval_iter', type=int, default=2000) | ||
parser.add_argument('--save_iter', type=int, default=2000) | ||
parser.add_argument('--vis_dir', type=str, metavar='PATH') | ||
|
||
def get_args(sys_args): | ||
global_args = parser.parse_args(sys_args) | ||
return global_args |
Empty file.
Oops, something went wrong.