Skip to content

Commit f133b3f

Browse files
authored
Fix bug in conversation agent (#347)
* fix conversation agent * fix out of index error
1 parent 0c533f2 commit f133b3f

File tree

4 files changed

+134
-70
lines changed

4 files changed

+134
-70
lines changed

vision_agent/agent/agent_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def format_conversation(chat: List[AgentMessage]) -> str:
160160
prompt = ""
161161
for chat_i in chat:
162162
if chat_i.role == "user" or chat_i.role == "coder":
163-
if "<final_code>" in chat_i.role:
163+
if "<final_code>" in chat_i.content:
164164
prompt += f"OBSERVATION: {chat_i.content}\n\n"
165165
elif chat_i.role == "user":
166166
prompt += f"USER: {chat_i.content}\n\n"

vision_agent/agent/vision_agent_coder_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def generate_code_from_plan(
443443

444444
# we don't need the user_interaction response for generating code since it's
445445
# already in the plan context
446-
while chat[-1].role != "user":
446+
while len(chat) > 0 and chat[-1].role != "user":
447447
chat.pop()
448448

449449
if not chat:

vision_agent/agent/vision_agent_planner_prompts_v2.py

Lines changed: 126 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
[suggestion 0]
5151
The image is very large and the items you need to detect are small.
5252
53-
Step 1: You should start by splitting the image into sections and runing the detection algorithm on each section:
53+
Step 1: You should start by splitting the image into overlapping sections and running the detection algorithm on each section:
5454
5555
def subdivide_image(image):
5656
height, width, _ = image.shape
@@ -66,41 +66,96 @@ def subdivide_image(image):
6666
6767
get_tool_for_task('<your prompt here>', subdivide_image(image))
6868
69-
Step 2: Once you have the detections from each subdivided image, you will need to merge them back together to remove overlapping predictions:
70-
71-
def translate_ofset(bbox, offset_x, offset_y):
72-
return (bbox[0] + offset_x, bbox[1] + offset_y, bbox[2] + offset_x, bbox[3] + offset_y)
73-
74-
def bounding_boxes_overlap(bbox1, bbox2):
75-
if bbox1[2] < bbox2[0] or bbox2[0] > bbox1[2]:
76-
return False
77-
if bbox1[3] < bbox2[1] or bbox2[3] > bbox1[3]:
78-
return False
79-
return True
80-
81-
def merge_bounding_boxes(bbox1, bbox2):
82-
x_min = min(bbox1[0], bbox2[0])
83-
y_min = min(bbox1[1], bbox2[1])
84-
x_max = max(bbox1[2], bbox2[2])
85-
y_max = max(bbox1[3], bbox2[3])
86-
return (x_min, y_min, x_max, y_max)
87-
88-
def merge_bounding_box_list(bboxes):
89-
merged_bboxes = []
90-
while bboxes:
91-
bbox = bboxes.pop()
92-
overlap_found = False
93-
for i, other_bbox in enumerate(merged_bboxes):
94-
if bounding_boxes_overlap(bbox, other_bbox):
95-
merged_bboxes[i] = merge_bounding_boxes(bbox, other_bbox)
96-
overlap_found = True
69+
Step 2: Once you have the detections from each subdivided image, you will need to merge them back together to remove overlapping predictions, be sure to translate the offset back to the original image:
70+
71+
def bounding_box_match(b1: List[float], b2: List[float], iou_threshold: float = 0.1) -> bool:
72+
# Calculate intersection coordinates
73+
x1 = max(b1[0], b2[0])
74+
y1 = max(b1[1], b2[1])
75+
x2 = min(b1[2], b2[2])
76+
y2 = min(b1[3], b2[3])
77+
78+
# Calculate intersection area
79+
if x2 < x1 or y2 < y1:
80+
return False # No overlap
81+
82+
intersection = (x2 - x1) * (y2 - y1)
83+
84+
# Calculate union area
85+
area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
86+
area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
87+
union = area1 + area2 - intersection
88+
89+
# Calculate IoU
90+
iou = intersection / union if union > 0 else 0
91+
92+
return iou >= iou_threshold
93+
94+
def merge_bounding_box_list(detections):
95+
merged_detections = []
96+
for detection in detections:
97+
matching_box = None
98+
for i, other in enumerate(merged_detections):
99+
if bounding_box_match(detection["bbox"], other["bbox"]):
100+
matching_box = i
97101
break
98-
if not overlap_found:
99-
p
100-
merged_bboxes.append(bbox)
101-
return merged_bboxes
102102
103-
detection = merge_bounding_box_list(detection_from_subdivided_images)
103+
if matching_box is not None:
104+
# Keep the box with higher confidence score
105+
if detection["score"] > merged_detections[matching_box]["score"]:
106+
merged_detections[matching_box] = detection
107+
else:
108+
merged_detections.append(detection)
109+
110+
def sub_image_to_original(elt, sub_image_position, original_size):
111+
offset_x, offset_y = sub_image_position
112+
return {
113+
"label": elt["label"],
114+
"score": elt["score"],
115+
"bbox": [
116+
(elt["bbox"][0] + offset_x) / original_size[1],
117+
(elt["bbox"][1] + offset_y) / original_size[0],
118+
(elt["bbox"][2] + offset_x) / original_size[1],
119+
(elt["bbox"][3] + offset_y) / original_size[0],
120+
],
121+
}
122+
123+
def normalized_to_unnormalized(elt, image_size):
124+
return {
125+
"label": elt["label"],
126+
"score": elt["score"],
127+
"bbox": [
128+
elt["bbox"][0] * image_size[1],
129+
elt["bbox"][1] * image_size[0],
130+
elt["bbox"][2] * image_size[1],
131+
elt["bbox"][3] * image_size[0],
132+
],
133+
}
134+
135+
height, width, _ = image.shape
136+
mid_width = width // 2
137+
mid_height = height // 2
138+
139+
detection_from_subdivided_images = []
140+
for i, sub_image in enumerate(subdivided_images):
141+
detections = <your detection function here>("pedestrian", sub_image)
142+
unnorm_detections = [
143+
normalized_to_unnormalized(
144+
detection, (sub_image.shape[0], sub_image.shape[1])
145+
)
146+
for detection in detections
147+
]
148+
offset_x = i % 2 * (mid_width - int(mid_width * 0.1))
149+
offset_y = i // 2 * (mid_height - int(mid_height * 0.1))
150+
offset_detections = [
151+
sub_image_to_original(
152+
unnorm_detection, (offset_x, offset_y), (height, width)
153+
)
154+
for unnorm_detection in unnorm_detections
155+
]
156+
detection_from_subdivided_images.extend(offset_detections)
157+
158+
detections = merge_bounding_box_list(detection_from_subdivided_images)
104159
[end of suggestion 0]
105160
[end of suggestion]
106161
<count>9</count>
@@ -164,36 +219,44 @@ def subdivide_image(image):
164219
165220
AGENT: <thinking>I need to now merge the boxes from all region and use the countgd_object_detection tool with the prompt 'pedestrian' as suggested by get_tool_for_task.</thinking>
166221
<execute_python>
167-
def translate_ofset(bbox, offset_x, offset_y):
168-
return (bbox[0] + offset_x, bbox[1] + offset_y, bbox[2] + offset_x, bbox[3] + offset_y)
169-
170-
def bounding_boxes_overlap(bbox1, bbox2):
171-
if bbox1[2] < bbox2[0] or bbox2[0] > bbox1[2]:
172-
return False
173-
if bbox1[3] < bbox2[1] or bbox2[3] > bbox1[3]:
174-
return False
175-
return True
176-
177-
def merge_bounding_boxes(bbox1, bbox2):
178-
x_min = min(bbox1[0], bbox2[0])
179-
y_min = min(bbox1[1], bbox2[1])
180-
x_max = max(bbox1[2], bbox2[2])
181-
y_max = max(bbox1[3], bbox2[3])
182-
return (x_min, y_min, x_max, y_max)
183-
184-
def merge_bounding_box_list(bboxes):
185-
merged_bboxes = []
186-
while bboxes:
187-
bbox = bboxes.pop()
188-
overlap_found = False
189-
for i, other_bbox in enumerate(merged_bboxes):
190-
if bounding_boxes_overlap(bbox, other_bbox):
191-
merged_bboxes[i] = merge_bounding_boxes(bbox, other_bbox)
192-
overlap_found = True
222+
def bounding_box_match(b1: List[float], b2: List[float], iou_threshold: float = 0.1) -> bool:
223+
# Calculate intersection coordinates
224+
x1 = max(b1[0], b2[0])
225+
y1 = max(b1[1], b2[1])
226+
x2 = min(b1[2], b2[2])
227+
y2 = min(b1[3], b2[3])
228+
229+
# Calculate intersection area
230+
if x2 < x1 or y2 < y1:
231+
return False # No overlap
232+
233+
intersection = (x2 - x1) * (y2 - y1)
234+
235+
# Calculate union area
236+
area1 = (b1[2] - b1[0]) * (b1[3] - b1[1])
237+
area2 = (b2[2] - b2[0]) * (b2[3] - b2[1])
238+
union = area1 + area2 - intersection
239+
240+
# Calculate IoU
241+
iou = intersection / union if union > 0 else 0
242+
243+
return iou >= iou_threshold
244+
245+
def merge_bounding_box_list(detections):
246+
merged_detections = []
247+
for detection in detections:
248+
matching_box = None
249+
for i, other in enumerate(merged_detections):
250+
if bounding_box_match(detection["bbox"], other["bbox"]):
251+
matching_box = i
193252
break
194-
if not overlap_found:
195-
merged_bboxes.append(bbox)
196-
return merged_bboxes
253+
254+
if matching_box is not None:
255+
# Keep the box with higher confidence score
256+
if detection["score"] > merged_detections[matching_box]["score"]:
257+
merged_detections[matching_box] = detection
258+
else:
259+
merged_detections.append(detection)
197260
198261
detections = []
199262
for region in subdivide_image(image):

vision_agent/agent/vision_agent_v2.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
def extract_conversation(
30-
chat: List[AgentMessage],
30+
chat: List[AgentMessage], include_conv: bool = False
3131
) -> Tuple[List[AgentMessage], Optional[str]]:
3232
chat = copy.deepcopy(chat)
3333

@@ -43,6 +43,8 @@ def extract_conversation(
4343
elif chat_i.role == "coder":
4444
if "<final_code>" in chat_i.content:
4545
extracted_chat.append(chat_i)
46+
elif include_conv and chat_i.role == "conversation":
47+
extracted_chat.append(chat_i)
4648

4749
# only keep the last <final_code> and <final_test>
4850
final_code = None
@@ -64,10 +66,9 @@ def extract_conversation(
6466

6567

6668
def run_conversation(agent: LMM, chat: List[AgentMessage]) -> str:
67-
extracted_chat, _ = extract_conversation(chat)
68-
extracted_chat = extracted_chat[-10:]
69+
extracted_chat, _ = extract_conversation(chat, include_conv=True)
6970

70-
conv = format_conversation(chat)
71+
conv = format_conversation(extracted_chat)
7172
prompt = CONVERSATION.format(
7273
conversation=conv,
7374
)
@@ -263,7 +264,7 @@ def chat(
263264
# do not append updated_chat to return_chat because the observation
264265
# from running the action will have already been added via the callbacks
265266
obs_response_context = run_conversation(
266-
self.agent, return_chat + updated_chat
267+
self.agent, int_chat + return_chat + updated_chat
267268
)
268269
return_chat.append(
269270
AgentMessage(role="conversation", content=obs_response_context)

0 commit comments

Comments
 (0)