-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathfigure_animation.py
69 lines (66 loc) · 3.64 KB
/
figure_animation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os,copy,glob,json
import easyocr,onnxruntime,onnx
from utils.classification.figure_classification import *
from utils.GDINO.groundingdino.util.inference import load_model, load_image, predict, annotate, Model
from utils.animation.get_gif import *
from utils.animation.table_animation import *
from utils.animation.select_presentation import *
if __name__ == "__main__":
    # Command-line entry point: classify every figure image in --input, build an
    # animated GIF for each supported figure type, then select figures for a
    # presentation deck.
    parser = argparse.ArgumentParser("FIGA", add_help=True)
    ## Grounding DINO (GSAM) detection settings
    parser.add_argument("--config_file", type=str, default="utils/GDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", help="path to config file")
    parser.add_argument("--grounded_checkpoint", default="utils/pretrained_models/GSAM/groundingdino_swint_ogc.pth", type=str, help="path to checkpoint file")
    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
    # FIX: help text previously claimed "running on cpu only!, default=False".
    parser.add_argument("--device", type=str, default="cpu", help="device to run on, e.g. 'cpu' or 'cuda'")
    ## DocFigure classification model (ONNX)
    parser.add_argument("--classifier", type=str, default="utils/pretrained_models/classification/figure_classification.onnx")
    ## Animation
    parser.add_argument("--duration", type=int, default=100, help="per-frame duration of the generated GIF")
    ## I/O
    # FIX: --input is read with os.listdir below, so it is a directory, not a file.
    parser.add_argument("--input", type=str, required=True, help="path to directory of input images")
    parser.add_argument("--output_dir", type=str, default="outputs3", help="output directory")
    parser.add_argument("--presentation_dir", type=str, default="presentation", help="output directory for presentation")
    args = parser.parse_args()

    # Create output directories up front.
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(args.presentation_dir, exist_ok=True)

    # Figure-type classifier (ONNX Runtime session).
    fig_ort_session = onnxruntime.InferenceSession(args.classifier)
    # Grounding DINO detector.
    model = load_model(args.config_file, args.grounded_checkpoint, args.device)
    # OCR reader for table/algorithm figures.
    # FIX: GPU usage now follows --device instead of being hard-coded to True,
    # so the default cpu-only run no longer tries to initialise CUDA.
    reader = easyocr.Reader(['ch_sim', 'en'], gpu=(args.device != "cpu"))

    # Collect image file names (not full paths) from the input directory.
    image_files = [file for file in os.listdir(args.input) if file.endswith((".jpg", ".png", ".jpeg"))]
    figure_type_list = []

    for img_name in image_files:
        img_file = os.path.join(args.input, img_name)
        print(f"Processing {img_name}.")
        # Classify the figure type (labels such as 'Tables', 'Scatter plot', ...).
        figure_type = figure_type_detection(fig_ort_session, img_file)
        figure_type_list.append(figure_type)
        # Skip figure types we cannot animate.
        if figure_type not in ['Tables','Algorithm','Scatter plot','Pie chart','Graph plots','Bar plots','Box plot','Histogram','Confusion matrix']:
            continue
        if figure_type in ('Tables', 'Algorithm'):
            # Tables/algorithms are animated via OCR text regions.
            table_gif(img_file, args.duration, reader, args.output_dir)
            continue
        if figure_type == 'Scatter plot':
            # FIX: the original compared against 'scatter', a label that can never
            # reach this point (the whitelist above only admits 'Scatter plot'),
            # so this branch was dead and scatter plots fell through to the
            # detector. Scatter plots need no box detection: animate directly
            # with an empty box list.
            get_gif(img_file, args.output_dir, figure_type, args.duration, [])
            continue
        # Remaining chart types: detect plot elements with Grounding DINO, then
        # animate the detected regions.
        text_prompt = get_prompt(figure_type)
        image_pil, image = load_image(img_file)
        boxes, logits, phrases = predict(model, image, text_prompt, args.box_threshold, args.text_threshold, args.device)
        detected_boxes = get_box(image_pil, copy.deepcopy(boxes))
        get_gif(img_file, args.output_dir, figure_type, args.duration, detected_boxes)

    # Choose the best animated figures for the presentation output.
    select_for_presentation(args.input, args.output_dir, figure_type_list, image_files, args.presentation_dir)