-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathread_train_dataset.py
40 lines (30 loc) · 1.2 KB
/
read_train_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import xml.etree.ElementTree as ET
from os import listdir

import cv2
import numpy as np
def _read_coordinate(box, tag):
    """Return the integer value stored in child element *tag* of *box*."""
    text = box.find(tag).text
    try:
        return int(text)
    except ValueError:
        # Some labelling tools write pixel coordinates as floats ("48.0");
        # accept them and truncate toward zero.
        return int(float(text))


def read_annotation(xml_file: str):
    """Parse an XML annotation file into bounding boxes and an image name.

    The file is expected to contain ``<object>`` elements, each with a
    ``<name>`` label and ``<bndbox>`` children holding ``<xmin>``,
    ``<ymin>``, ``<xmax>`` and ``<ymax>`` (the layout used by Pascal VOC
    style annotations).

    Args:
        xml_file: Path of the XML file to parse.

    Returns:
        A ``(bounding_box_list, file_name)`` tuple.  ``bounding_box_list``
        holds one ``[label, x_min, y_min, x_max, y_max]`` list per
        ``<bndbox>`` element; ``file_name`` is the text of the root's
        ``<filename>`` child, or ``None`` when that element is absent.

    Raises:
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        AttributeError: If a ``<bndbox>`` lacks one of its coordinate tags.
        ValueError: If a coordinate is not numeric.
    """
    root = ET.parse(xml_file).getroot()
    file_name = root.findtext('filename')

    bounding_box_list = []
    for obj in root.iter('object'):
        object_label = obj.findtext("name")
        for box in obj.findall("bndbox"):
            # Build the entry inside the <bndbox> loop so an object that has
            # no box is skipped instead of inheriting stale (or undefined)
            # coordinates from a previous iteration.
            bounding_box_list.append([
                object_label,
                _read_coordinate(box, "xmin"),
                _read_coordinate(box, "ymin"),
                _read_coordinate(box, "xmax"),
                _read_coordinate(box, "ymax"),
            ])

    return bounding_box_list, file_name
def _images_to_array(images):
    """Convert a list of decoded images into a NumPy array."""
    try:
        return np.array(images)
    except ValueError:
        # Images of differing shapes cannot be stacked into one regular
        # array (NumPy >= 1.24 raises here); fall back to a 1-D object
        # array holding one image per element.
        ragged = np.empty(len(images), dtype=object)
        for index, image in enumerate(images):
            ragged[index] = image
        return ragged


def read_train_dataset(dir):
    """Load every ``.jpg``/``.png`` image in *dir* with its XML annotation.

    For each image ``<stem>.<ext>`` the annotation is read from
    ``<stem>.xml`` in the same directory via ``read_annotation``.

    Args:
        dir: Directory containing the images and their XML files.  A
            trailing path separator is accepted but not required.

    Returns:
        A ``(images, annotations)`` tuple.  ``images`` is a NumPy array of
        the values returned by ``cv2.imread`` (a 1-D object array when the
        images do not share one shape).  ``annotations`` is a list, in the
        same order, of ``(bounding_box_list, annotation_file, file_name)``
        tuples where ``annotation_file`` is the XML file's base name.

    Raises:
        FileNotFoundError: If *dir* or an image's XML file does not exist.
    """
    images = []
    annotations = []

    # listdir() returns entries in arbitrary order; sort for reproducibility.
    for file in sorted(listdir(dir)):
        # Compare the real extension: a substring test would also accept
        # names such as "jpg_notes.txt" or "mypng.xml".
        stem, extension = os.path.splitext(file)
        if extension.lower() not in ('.jpg', '.png'):
            continue

        # Flag 1 requests a 3-channel colour image.  NOTE: cv2.imread
        # returns None instead of raising when the file cannot be decoded.
        images.append(cv2.imread(os.path.join(dir, file), 1))

        # Swap only the extension; str.replace would also rewrite any
        # earlier occurrence of the extension text inside the name.
        annotation_file = stem + '.xml'
        bounding_box_list, file_name = read_annotation(
            os.path.join(dir, annotation_file)
        )
        annotations.append((bounding_box_list, annotation_file, file_name))

    return _images_to_array(images), annotations