#!/bin/bash
# ------------------- Args setting -------------------
MODEL=$1
MODEL_T=$2
BATCH_SIZE=$3
DATASET=$4
DATASET_ROOT=$5
WORLD_SIZE=$6
RESUME=$7
# NOTE: PRETRAINED_MODEL is referenced by --pretrained below but is not taken
# from the positional arguments; set or export it before running, e.g.
#   PRETRAINED_MODEL=path/to/checkpoint.pth bash train_linprob.sh ...
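# Example invocation (values here are illustrative only, not part of the
# original script):
#   PRETRAINED_MODEL=path/to/checkpoint.pth \
#   bash train_linprob.sh vit_tiny vit_tiny 256 cifar10 none 1 none
# i.e. MODEL=vit_tiny, MODEL_T=vit_tiny, BATCH_SIZE=256, DATASET=cifar10,
# DATASET_ROOT=none, WORLD_SIZE=1, RESUME=none.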
# ------------------- Training setting -------------------
## Epoch
MAX_EPOCH=100
WP_EPOCH=5
EVAL_EPOCH=5
## Optimizer
BASE_LR=0.1
MIN_LR=0.0
WEIGHT_DECAY=0.05
# ------------------- Dataset setting -------------------
# DATASET="cifar10"   # uncomment to hard-code the dataset for quick debugging;
#                     # leaving it active would override the 4th argument.
if [[ $DATASET == "cifar10" || $DATASET == "cifar100" ]]; then
    # Data root (no explicit path needed for CIFAR)
    ROOT="none"
    # Image config
    IMG_SIZE=32
    PATCH_SIZE=2
elif [[ $DATASET == "imagenet_1k" || $DATASET == "imagenet_22k" ]]; then
    # Data root: use the 5th argument if given, otherwise the placeholder path
    ROOT=${DATASET_ROOT:-"path/to/imagenet"}
    # Image config
    IMG_SIZE=224
    PATCH_SIZE=16
elif [[ $DATASET == "custom" ]]; then
    # Data root: use the 5th argument if given, otherwise the placeholder path
    ROOT=${DATASET_ROOT:-"path/to/custom"}
    # Image config
    IMG_SIZE=224
    PATCH_SIZE=16
else
    echo "Unknown dataset: ${DATASET}"
    exit 1
fi
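# Optional sanity check (a sketch added here, not part of the original script):
# for the ImageNet/custom branches ROOT should be an existing directory, while
# the CIFAR branches intentionally use ROOT="none".
if [[ $ROOT != "none" && ! -d $ROOT ]]; then
    echo "Dataset root '${ROOT}' does not exist."
    exit 1
fi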
# ------------------- Training pipeline -------------------
# WORLD_SIZE=1   # uncomment to force single-GPU mode; leaving it active would
#                # override the 6th argument.
if [[ $WORLD_SIZE -eq 1 ]]; then
python main_linprobe.py \
--cuda \
--root ${ROOT} \
--dataset ${DATASET} \
--model ${MODEL} \
--batch_size ${BATCH_SIZE} \
--img_size ${IMG_SIZE} \
--patch_size ${PATCH_SIZE} \
--max_epoch ${MAX_EPOCH} \
--wp_epoch ${WP_EPOCH} \
--eval_epoch ${EVAL_EPOCH} \
--base_lr ${BASE_LR} \
--min_lr ${MIN_LR} \
--weight_decay ${WEIGHT_DECAY} \
--resume ${RESUME} \
--pretrained ${PRETRAINED_MODEL}
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port 1668 main_linprobe.py \
--cuda \
-dist \
--root ${ROOT} \
--dataset ${DATASET} \
--model ${MODEL} \
--batch_size ${BATCH_SIZE} \
--img_size ${IMG_SIZE} \
--patch_size ${PATCH_SIZE} \
--max_epoch ${MAX_EPOCH} \
--wp_epoch ${WP_EPOCH} \
--eval_epoch ${EVAL_EPOCH} \
--base_lr ${BASE_LR} \
--min_lr ${MIN_LR} \
--weight_decay ${WEIGHT_DECAY} \
--resume ${RESUME} \
--pretrained ${PRETRAINED_MODEL}
else
echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
multi-card training mode, which is currently unsupported."
exit 1
fi
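# Sketch only: multi-node training is unsupported above, but with torch.distributed.run
# it would roughly take the form below (node count, address, and port are placeholders,
# and main_linprobe.py is assumed to accept the same flags as in the branches above):
#   python -m torch.distributed.run --nnodes=2 --node_rank=0 \
#       --master_addr=<MASTER_IP> --master_port=1668 \
#       --nproc_per_node=8 main_linprobe.py --cuda -dist ...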