Skip to content

Commit 5301022

Browse files
committed
Hyperparameters.
1 parent ba6bc8d commit 5301022

File tree

4 files changed

+10
-11
lines changed

4 files changed

+10
-11
lines changed

extract_coco_data.py

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@
4646
img_score = img_out[target].item()
4747
img_gt_score = img_gt_out[target].item()
4848
if img_score < IMG_THRESH and img_gt_score >= OBJ_THRESH:
49-
print(target, img_out, img_gt_out)
5049
if is_car:
5150
car_imgs += 1
5251
img.save("coco_voc_images/car/{}.jpg".format(ann['id']))

extract_voc_data.py

-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def parse_data(year):
4646
img_score = img_out[target].item()
4747
img_gt_score = img_gt_out[target].item()
4848
if img_score < IMG_THRESH and img_gt_score >= OBJ_THRESH:
49-
print(target, img_out, img_gt_out)
5049
if is_car:
5150
car_imgs += 1
5251
img.save("coco_voc_images/car/{}".format(filename))

reinforcement.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from numpy import argmax
66
from classifier.ResNet import ResNet
77

8+
CONFIDENCE_THRESHOLD = 0.8
9+
810
# load the pre-trained classifier (trained on ImageNet)
911
classifier = ResNet().to(device)
1012
classifier.load_state_dict(torch.load("classifier/init_model.pth"))
@@ -76,7 +78,7 @@ def take_action(state, action):
7678
conf_new = calculate_conf(next_state)
7779

7880
if done:
79-
if conf_new >= 0.9:
81+
if conf_new >= CONFIDENCE_THRESHOLD:
8082
reward = 3.0
8183
else:
8284
reward = -3.0
@@ -95,12 +97,13 @@ def find_positive_actions(state):
9597

9698
def find_best_action(state):
9799
confs = []
98-
if calculate_conf(state) >= 0.9:
100+
if calculate_conf(state) >= CONFIDENCE_THRESHOLD:
99101
return 8
100102
for i in range(8):
101103
reward, next_state, done = take_action(state, i)
102104
confs.append(calculate_conf(next_state))
103105
best_next_state_conf = argmax(confs)
106+
#print([a.item() for a in confs])
104107
if calculate_conf(state) > confs[best_next_state_conf]:
105108
return None
106109
return best_next_state_conf

train.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
MODEL_PATH = "models"
88

99
# Hyperparameters / utilities
10-
BATCH_SIZE = 5
11-
NUM_EPOCHS = 40
10+
BATCH_SIZE = 10
11+
NUM_EPOCHS = 100
1212
GAMMA = 0.995
13-
EPS_START = 0.9
13+
EPS_START = 1
1414
EPS_END = 0.1
15-
EPS_LEN = 20 # number of epochs to decay epsilon
16-
TARGET_UPDATE = 10
15+
EPS_LEN = 25 # number of epochs to decay epsilon
1716

1817
eps_sched = np.linspace(EPS_START, EPS_END, EPS_LEN)
1918

@@ -40,7 +39,6 @@ def select_action(states, eps):
4039
action = random.choice(positive_actions)
4140
else:
4241
action = random.randrange(9)
43-
#action = random.randrange(9)
4442
actions.append(action)
4543
actions = torch.tensor(actions, device=device)
4644
print("random:", actions)
@@ -68,7 +66,7 @@ def select_action(states, eps):
6866
batch_steps = 0
6967
start = time.time()
7068
# perform actions on batch items until done
71-
while len(states) > 0 and batch_steps < 50:
69+
while len(states) > 0 and batch_steps < 40:
7270
actions = select_action(states, eps)
7371
states_new = []
7472
# store state transition for each (state, action) pair

0 commit comments

Comments
 (0)