|
| 1 | +import os |
| 2 | +import time |
| 3 | +import torch |
| 4 | +import argparse |
| 5 | +from torch import nn |
| 6 | +from torch.utils import data |
| 7 | +from torchvision import transforms |
| 8 | +from tools.utils import * |
| 9 | +import tools.model as models |
| 10 | +from dataset.scene_dataset import * |
| 11 | + |
| 12 | + |
| 13 | +def main(args): |
| 14 | + if args.dataID==1: |
| 15 | + DataName = 'UCM' |
| 16 | + num_classes = 21 |
| 17 | + classname = ('agricultural','airplane','baseballdiamond', |
| 18 | + 'beach','buildings','chaparral', |
| 19 | + 'denseresidential','forest','freeway', |
| 20 | + 'golfcourse','harbor','intersection', |
| 21 | + 'mediumresidential','mobilehomepark','overpass', |
| 22 | + 'parkinglot','river','runway', |
| 23 | + 'sparseresidential','storagetanks','tenniscourt') |
| 24 | + |
| 25 | + elif args.dataID==2: |
| 26 | + DataName = 'AID' |
| 27 | + num_classes = 30 |
| 28 | + classname = ('airport','bareland','baseballfield', |
| 29 | + 'beach','bridge','center', |
| 30 | + 'church','commercial','denseresidential', |
| 31 | + 'desert','farmland','forest', |
| 32 | + 'industrial','meadow','mediumresidential', |
| 33 | + 'mountain','parking','park', |
| 34 | + 'playground','pond','port', |
| 35 | + 'railwaystation','resort','river', |
| 36 | + 'school','sparseresidential','square', |
| 37 | + 'stadium','storagetanks','viaduct') |
| 38 | + |
| 39 | + |
| 40 | + print_per_batches = args.print_per_batches |
| 41 | + save_path_prefix = args.save_path_prefix+DataName+'/Pretrain/'+args.network+'/' |
| 42 | + |
| 43 | + if os.path.exists(save_path_prefix)==False: |
| 44 | + os.makedirs(save_path_prefix) |
| 45 | + |
| 46 | + composed_transforms = transforms.Compose([ |
| 47 | + transforms.Resize(size=(args.crop_size,args.crop_size)), |
| 48 | + transforms.ToTensor(), |
| 49 | + transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]) |
| 50 | + |
| 51 | + train_loader = data.DataLoader( |
| 52 | + scene_dataset(root_dir=args.root_dir,pathfile='./dataset/'+DataName+'_train.txt', transform=composed_transforms), |
| 53 | + batch_size=args.train_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True) |
| 54 | + |
| 55 | + val_loader = data.DataLoader( |
| 56 | + scene_dataset(root_dir=args.root_dir,pathfile='./dataset/'+DataName+'_test.txt', transform=composed_transforms), |
| 57 | + batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True) |
| 58 | + |
| 59 | + ###################Network Definition################### |
| 60 | + if args.network=='alexnet': |
| 61 | + Model = models.alexnet(pretrained=True) |
| 62 | + Model.classifier._modules['6'] = nn.Linear(4096, num_classes) |
| 63 | + elif args.network=='vgg11': |
| 64 | + Model = models.vgg11(pretrained=True) |
| 65 | + Model.classifier._modules['6'] = nn.Linear(4096, num_classes) |
| 66 | + elif args.network=='vgg16': |
| 67 | + Model = models.vgg16(pretrained=True) |
| 68 | + Model.classifier._modules['6'] = nn.Linear(4096, num_classes) |
| 69 | + elif args.network=='vgg19': |
| 70 | + Model = models.vgg19(pretrained=True) |
| 71 | + Model.classifier._modules['6'] = nn.Linear(4096, num_classes) |
| 72 | + elif args.network=='inception': |
| 73 | + Model = models.inception_v3(pretrained=True, aux_logits=False) |
| 74 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 75 | + elif args.network=='resnet18': |
| 76 | + Model = models.resnet18(pretrained=True) |
| 77 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 78 | + elif args.network=='resnet50': |
| 79 | + Model = models.resnet50(pretrained=True) |
| 80 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 81 | + elif args.network=='resnet101': |
| 82 | + Model = models.resnet101(pretrained=True) |
| 83 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 84 | + elif args.network=='resnext50_32x4d': |
| 85 | + Model = models.resnext50_32x4d(pretrained=True) |
| 86 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 87 | + elif args.network=='resnext101_32x8d': |
| 88 | + Model = models.resnext101_32x8d(pretrained=True) |
| 89 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 90 | + elif args.network=='densenet121': |
| 91 | + Model = models.densenet121(pretrained=True) |
| 92 | + Model.classifier = nn.Linear(1024, num_classes) |
| 93 | + elif args.network=='densenet169': |
| 94 | + Model = models.densenet169(pretrained=True) |
| 95 | + Model.classifier = nn.Linear(1664, num_classes) |
| 96 | + elif args.network=='densenet201': |
| 97 | + Model = models.densenet201(pretrained=True) |
| 98 | + Model.classifier = nn.Linear(1920, num_classes) |
| 99 | + elif args.network=='regnet_x_400mf': |
| 100 | + Model = models.regnet_x_400mf(pretrained=True) |
| 101 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 102 | + elif args.network=='regnet_x_8gf': |
| 103 | + Model = models.regnet_x_8gf(pretrained=True) |
| 104 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 105 | + elif args.network=='regnet_x_16gf': |
| 106 | + Model = models.regnet_x_16gf(pretrained=True) |
| 107 | + Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes) |
| 108 | + |
| 109 | + Model = torch.nn.DataParallel(Model).cuda() |
| 110 | + Model_optimizer = torch.optim.Adam(Model.parameters(),lr=args.lr) |
| 111 | + num_batches = len(train_loader) |
| 112 | + |
| 113 | + cls_loss = torch.nn.CrossEntropyLoss() |
| 114 | + num_steps = args.num_epochs*num_batches |
| 115 | + hist = np.zeros((num_steps,3)) |
| 116 | + index_i = -1 |
| 117 | + |
| 118 | + for epoch in range(args.num_epochs): |
| 119 | + for batch_index, src_data in enumerate(train_loader): |
| 120 | + index_i += 1 |
| 121 | + |
| 122 | + tem_time = time.time() |
| 123 | + Model.train() |
| 124 | + Model_optimizer.zero_grad() |
| 125 | + |
| 126 | + X_train, Y_train, _ = src_data |
| 127 | + X_train = X_train.cuda() |
| 128 | + Y_train = Y_train.cuda() |
| 129 | + |
| 130 | + _,output = Model(X_train) |
| 131 | + |
| 132 | + # CE Loss |
| 133 | + _, src_prd_label = torch.max(output, 1) |
| 134 | + cls_loss_value = cls_loss(output, Y_train) |
| 135 | + cls_loss_value.backward() |
| 136 | + |
| 137 | + Model_optimizer.step() |
| 138 | + |
| 139 | + hist[index_i,0] = time.time()-tem_time |
| 140 | + hist[index_i,1] = cls_loss_value.item() |
| 141 | + hist[index_i,2] = torch.mean((src_prd_label == Y_train).float()).item() |
| 142 | + |
| 143 | + tem_time = time.time() |
| 144 | + if (batch_index+1) % print_per_batches == 0: |
| 145 | + print('Epoch %d/%d: %d/%d Time: %.2f cls_loss = %.3f acc = %.3f \n'\ |
| 146 | + %(epoch+1, args.num_epochs,batch_index+1,num_batches, |
| 147 | + np.mean(hist[index_i-print_per_batches+1:index_i+1,0]), |
| 148 | + np.mean(hist[index_i-print_per_batches+1:index_i+1,1]), |
| 149 | + np.mean(hist[index_i-print_per_batches+1:index_i+1,2]))) |
| 150 | + |
| 151 | + |
| 152 | + |
| 153 | + OA_new,_ = test_acc(Model,classname, val_loader, epoch+1,num_classes,print_per_batches=10) |
| 154 | + |
| 155 | + model_name = 'epoch_'+str(epoch+1)+'_OA_'+repr(int(OA_new*10000))+'.pth' |
| 156 | + |
| 157 | + print('Save Model') |
| 158 | + torch.save(Model.state_dict(), os.path.join(save_path_prefix, model_name)) |
| 159 | + |
| 160 | + |
| 161 | + |
| 162 | +if __name__ == '__main__': |
| 163 | + parser = argparse.ArgumentParser() |
| 164 | + |
| 165 | + parser.add_argument('--dataID', type=int, default=1) |
| 166 | + parser.add_argument('--network', type=str, default='resnet18', |
| 167 | + help='alexnet,vgg11,vgg16,vgg19,inception,resnet18,resnet50,resnet101,resnext50_32x4d,resnext101_32x8d,densenet121,densenet169,densenet201,regnet_x_400mf,regnet_x_8gf,regnet_x_16gf') |
| 168 | + parser.add_argument('--save_path_prefix', type=str, default='./') |
| 169 | + parser.add_argument('--root_dir', type=str, default='/iarai/home/yonghao.xu/Data/',help='dataset path.') |
| 170 | + parser.add_argument('--train_batch_size', type=int, default=64) |
| 171 | + parser.add_argument('--val_batch_size', type=int, default=64) |
| 172 | + parser.add_argument('--num_workers', type=int, default=1) |
| 173 | + parser.add_argument('--lr', type=float, default=1e-4) |
| 174 | + parser.add_argument('--crop_size', type=int, default=256) |
| 175 | + parser.add_argument('--num_epochs', type=int, default=10) |
| 176 | + parser.add_argument('--print_per_batches', type=int, default=5) |
| 177 | + |
| 178 | + main(parser.parse_args()) |
0 commit comments