Commit 07e503c

Add files via upload
1 parent 49f1262 commit 07e503c

22 files changed: +13,077 −0 lines

dataset/AID_mixcut_sample.png (598 KB)
dataset/AID_mixup_sample.png (513 KB)
dataset/AID_test.txt (+5,000 lines; large diff not rendered)
dataset/AID_train.txt (+5,000 lines; large diff not rendered)
dataset/UCM_mixcut_sample.png (120 KB)
dataset/UCM_mixup_sample.png (103 KB)
dataset/UCM_test.txt (+1,050 lines; large diff not rendered)
dataset/UCM_train.txt (+1,050 lines; large diff not rendered)
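
The train/test lists themselves are not rendered in the diff, but the loader in dataset/scene_dataset.py below shows the expected format: each line holds a relative image path followed by an integer class label, separated by whitespace. A minimal parsing sketch with a purely hypothetical example line (the path and label are illustrative, not taken from the actual files):

# Hypothetical list-file entry; real lines follow the same "<relative path> <label>" layout.
line = "UCM/airplane/airplane00.tif 1"
words = line.split()
path, label = words[0], int(words[1])      # image path relative to root_dir, integer class id
name = path.split('/')[-1].split('.')[0]   # file stem, e.g. 'airplane00', carried as the sample name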

dataset/scene_dataset.py (+38 lines)
@@ -0,0 +1,38 @@
from torch.utils.data import Dataset
from PIL import Image


def default_loader(path):
    # Load an image from disk and convert it to 3-channel RGB.
    return Image.open(path).convert('RGB')


class scene_dataset(Dataset):
    def __init__(self, root_dir, pathfile, transform=None, loader=default_loader, mode='clean'):
        # Each line of `pathfile` holds an image path and an integer class label.
        pf = open(pathfile, 'r')
        imgs = []
        if mode == 'clean':
            for line in pf:
                line = line.rstrip('\n')
                words = line.split()
                name = words[0].split('/')[-1].split('.')[0]
                imgs.append((root_dir + words[0], int(words[1]), name))
        elif mode == 'adv':
            # Adversarial mode: point at the pre-generated '<name>_adv.png' counterpart instead.
            for line in pf:
                line = line.rstrip('\n')
                words = line.split()
                name = words[0].split('/')[-1].split('.')[0]
                imgs.append((root_dir + name + '_adv.png', int(words[1]), name))

        self.imgs = imgs
        self.transform = transform
        self.loader = loader
        pf.close()

    def __getitem__(self, index):
        fn, label, name = self.imgs[index]
        img = self.loader(fn)
        if self.transform is not None:
            img = self.transform(img)
        return img, label, name

    def __len__(self):
        return len(self.imgs)
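
A minimal usage sketch for this class, mirroring how pretrain_cls.py below builds its loaders; the root_dir value is a placeholder, not a path from this repo:

from torch.utils import data
from torchvision import transforms
from dataset.scene_dataset import scene_dataset

# Same preprocessing chain as pretrain_cls.py: resize, tensor conversion, ImageNet normalization.
composed_transforms = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

# root_dir is a placeholder; pathfile points at one of the list files added in this commit.
train_set = scene_dataset(root_dir='/path/to/Data/', pathfile='./dataset/UCM_train.txt',
                          transform=composed_transforms)
train_loader = data.DataLoader(train_set, batch_size=64, shuffle=True,
                               num_workers=1, pin_memory=True)

img, label, name = train_set[0]   # (transformed image tensor, int label, file stem)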

pretrain_cls.py (+178 lines)
@@ -0,0 +1,178 @@
import os
import time
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils import data
from torchvision import transforms
from tools.utils import *
import tools.model as models
from dataset.scene_dataset import *


def main(args):
    if args.dataID == 1:
        DataName = 'UCM'
        num_classes = 21
        classname = ('agricultural','airplane','baseballdiamond',
                     'beach','buildings','chaparral',
                     'denseresidential','forest','freeway',
                     'golfcourse','harbor','intersection',
                     'mediumresidential','mobilehomepark','overpass',
                     'parkinglot','river','runway',
                     'sparseresidential','storagetanks','tenniscourt')

    elif args.dataID == 2:
        DataName = 'AID'
        num_classes = 30
        classname = ('airport','bareland','baseballfield',
                     'beach','bridge','center',
                     'church','commercial','denseresidential',
                     'desert','farmland','forest',
                     'industrial','meadow','mediumresidential',
                     'mountain','parking','park',
                     'playground','pond','port',
                     'railwaystation','resort','river',
                     'school','sparseresidential','square',
                     'stadium','storagetanks','viaduct')

    print_per_batches = args.print_per_batches
    save_path_prefix = args.save_path_prefix + DataName + '/Pretrain/' + args.network + '/'

    if os.path.exists(save_path_prefix) == False:
        os.makedirs(save_path_prefix)

    composed_transforms = transforms.Compose([
        transforms.Resize(size=(args.crop_size, args.crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))])

    train_loader = data.DataLoader(
        scene_dataset(root_dir=args.root_dir, pathfile='./dataset/'+DataName+'_train.txt', transform=composed_transforms),
        batch_size=args.train_batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    val_loader = data.DataLoader(
        scene_dataset(root_dir=args.root_dir, pathfile='./dataset/'+DataName+'_test.txt', transform=composed_transforms),
        batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    ################### Network Definition ###################
    if args.network == 'alexnet':
        Model = models.alexnet(pretrained=True)
        Model.classifier._modules['6'] = nn.Linear(4096, num_classes)
    elif args.network == 'vgg11':
        Model = models.vgg11(pretrained=True)
        Model.classifier._modules['6'] = nn.Linear(4096, num_classes)
    elif args.network == 'vgg16':
        Model = models.vgg16(pretrained=True)
        Model.classifier._modules['6'] = nn.Linear(4096, num_classes)
    elif args.network == 'vgg19':
        Model = models.vgg19(pretrained=True)
        Model.classifier._modules['6'] = nn.Linear(4096, num_classes)
    elif args.network == 'inception':
        Model = models.inception_v3(pretrained=True, aux_logits=False)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'resnet18':
        Model = models.resnet18(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'resnet50':
        Model = models.resnet50(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'resnet101':
        Model = models.resnet101(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'resnext50_32x4d':
        Model = models.resnext50_32x4d(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'resnext101_32x8d':
        Model = models.resnext101_32x8d(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'densenet121':
        Model = models.densenet121(pretrained=True)
        Model.classifier = nn.Linear(1024, num_classes)
    elif args.network == 'densenet169':
        Model = models.densenet169(pretrained=True)
        Model.classifier = nn.Linear(1664, num_classes)
    elif args.network == 'densenet201':
        Model = models.densenet201(pretrained=True)
        Model.classifier = nn.Linear(1920, num_classes)
    elif args.network == 'regnet_x_400mf':
        Model = models.regnet_x_400mf(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'regnet_x_8gf':
        Model = models.regnet_x_8gf(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)
    elif args.network == 'regnet_x_16gf':
        Model = models.regnet_x_16gf(pretrained=True)
        Model.fc = torch.nn.Linear(Model.fc.in_features, num_classes)

    Model = torch.nn.DataParallel(Model).cuda()
    Model_optimizer = torch.optim.Adam(Model.parameters(), lr=args.lr)
    num_batches = len(train_loader)

    cls_loss = torch.nn.CrossEntropyLoss()
    num_steps = args.num_epochs * num_batches
    # Per-step history: [batch time, classification loss, training accuracy].
    hist = np.zeros((num_steps, 3))
    index_i = -1

    for epoch in range(args.num_epochs):
        for batch_index, src_data in enumerate(train_loader):
            index_i += 1

            tem_time = time.time()
            Model.train()
            Model_optimizer.zero_grad()

            X_train, Y_train, _ = src_data
            X_train = X_train.cuda()
            Y_train = Y_train.cuda()

            # The networks in tools.model return a (feature, logits) pair.
            _, output = Model(X_train)

            # CE Loss
            _, src_prd_label = torch.max(output, 1)
            cls_loss_value = cls_loss(output, Y_train)
            cls_loss_value.backward()

            Model_optimizer.step()

            hist[index_i, 0] = time.time() - tem_time
            hist[index_i, 1] = cls_loss_value.item()
            hist[index_i, 2] = torch.mean((src_prd_label == Y_train).float()).item()

            tem_time = time.time()
            if (batch_index+1) % print_per_batches == 0:
                print('Epoch %d/%d: %d/%d Time: %.2f cls_loss = %.3f acc = %.3f \n'\
                    %(epoch+1, args.num_epochs, batch_index+1, num_batches,
                      np.mean(hist[index_i-print_per_batches+1:index_i+1, 0]),
                      np.mean(hist[index_i-print_per_batches+1:index_i+1, 1]),
                      np.mean(hist[index_i-print_per_batches+1:index_i+1, 2])))

        # Evaluate on the test list and save a checkpoint after every epoch.
        OA_new, _ = test_acc(Model, classname, val_loader, epoch+1, num_classes, print_per_batches=10)

        model_name = 'epoch_'+str(epoch+1)+'_OA_'+repr(int(OA_new*10000))+'.pth'

        print('Save Model')
        torch.save(Model.state_dict(), os.path.join(save_path_prefix, model_name))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataID', type=int, default=1)
    parser.add_argument('--network', type=str, default='resnet18',
                        help='alexnet,vgg11,vgg16,vgg19,inception,resnet18,resnet50,resnet101,resnext50_32x4d,resnext101_32x8d,densenet121,densenet169,densenet201,regnet_x_400mf,regnet_x_8gf,regnet_x_16gf')
    parser.add_argument('--save_path_prefix', type=str, default='./')
    parser.add_argument('--root_dir', type=str, default='/iarai/home/yonghao.xu/Data/', help='dataset path.')
    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--val_batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--crop_size', type=int, default=256)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--print_per_batches', type=int, default=5)

    main(parser.parse_args())
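
As a usage sketch (the data path is a placeholder, not from this repo): running python pretrain_cls.py --dataID 1 --network resnet18 --root_dir /path/to/Data/ fine-tunes an ImageNet-pretrained ResNet-18 on the UCM training list for the default 10 epochs and, after each epoch, evaluates on the test list and saves a checkpoint to ./UCM/Pretrain/resnet18/ named epoch_<n>_OA_<overall accuracy x 10000>.pth.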
