File tree 4 files changed +10
-11
lines changed
4 files changed +10
-11
lines changed Original file line number Diff line number Diff line change 46
46
img_score = img_out [target ].item ()
47
47
img_gt_score = img_gt_out [target ].item ()
48
48
if img_score < IMG_THRESH and img_gt_score >= OBJ_THRESH :
49
- print (target , img_out , img_gt_out )
50
49
if is_car :
51
50
car_imgs += 1
52
51
img .save ("coco_voc_images/car/{}.jpg" .format (ann ['id' ]))
Original file line number Diff line number Diff line change @@ -46,7 +46,6 @@ def parse_data(year):
46
46
img_score = img_out [target ].item ()
47
47
img_gt_score = img_gt_out [target ].item ()
48
48
if img_score < IMG_THRESH and img_gt_score >= OBJ_THRESH :
49
- print (target , img_out , img_gt_out )
50
49
if is_car :
51
50
car_imgs += 1
52
51
img .save ("coco_voc_images/car/{}" .format (filename ))
Original file line number Diff line number Diff line change 5
5
from numpy import argmax
6
6
from classifier .ResNet import ResNet
7
7
8
+ CONFIDENCE_THRESHOLD = 0.8
9
+
8
10
# load the pre-trained classifier (trained on imagenet)
9
11
classifier = ResNet ().to (device )
10
12
classifier .load_state_dict (torch .load ("classifier/init_model.pth" ))
@@ -76,7 +78,7 @@ def take_action(state, action):
76
78
conf_new = calculate_conf (next_state )
77
79
78
80
if done :
79
- if conf_new >= 0.9 :
81
+ if conf_new >= CONFIDENCE_THRESHOLD :
80
82
reward = 3.0
81
83
else :
82
84
reward = - 3.0
@@ -95,12 +97,13 @@ def find_positive_actions(state):
95
97
96
98
def find_best_action (state ):
97
99
confs = []
98
- if calculate_conf (state ) >= 0.9 :
100
+ if calculate_conf (state ) >= CONFIDENCE_THRESHOLD :
99
101
return 8
100
102
for i in range (8 ):
101
103
reward , next_state , done = take_action (state , i )
102
104
confs .append (calculate_conf (next_state ))
103
105
best_next_state_conf = argmax (confs )
106
+ #print([a.item() for a in confs])
104
107
if calculate_conf (state ) > confs [best_next_state_conf ]:
105
108
return None
106
109
return best_next_state_conf
Original file line number Diff line number Diff line change 7
7
MODEL_PATH = "models"
8
8
9
9
# Hyperparameters / utilities
10
- BATCH_SIZE = 5
11
- NUM_EPOCHS = 40
10
+ BATCH_SIZE = 10
11
+ NUM_EPOCHS = 100
12
12
GAMMA = 0.995
13
- EPS_START = 0.9
13
+ EPS_START = 1
14
14
EPS_END = 0.1
15
- EPS_LEN = 20 # number of epochs to decay epsilon
16
- TARGET_UPDATE = 10
15
+ EPS_LEN = 25 # number of epochs to decay epsilon
17
16
18
17
eps_sched = np .linspace (EPS_START , EPS_END , EPS_LEN )
19
18
@@ -40,7 +39,6 @@ def select_action(states, eps):
40
39
action = random .choice (positive_actions )
41
40
else :
42
41
action = random .randrange (9 )
43
- #action = random.randrange(9)
44
42
actions .append (action )
45
43
actions = torch .tensor (actions , device = device )
46
44
print ("random:" , actions )
@@ -68,7 +66,7 @@ def select_action(states, eps):
68
66
batch_steps = 0
69
67
start = time .time ()
70
68
# perform actions on batch items until done
71
- while len (states ) > 0 and batch_steps < 50 :
69
+ while len (states ) > 0 and batch_steps < 40 :
72
70
actions = select_action (states , eps )
73
71
states_new = []
74
72
# store state transition for each each (state, action) pair
You can’t perform that action at this time.
0 commit comments