-
Notifications
You must be signed in to change notification settings - Fork 2
/
predict.py
40 lines (30 loc) · 1.06 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
"""Run a VisionEncoderDecoder (TrOCR-style) checkpoint on a single image.

Loads the ``microsoft/trocr-small-stage1`` processor and a local
``VisionEncoderDecoderModel`` checkpoint, decodes one image, and prints the
decoded text together with model-loading and inference wall-clock times.
"""
import os

# Set before importing torch so CUDA initialisation only sees GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import argparse
import time
from pathlib import Path

import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

DEFAULT_IMG_PATH = "dataset/UniMER-Test/cpe/0000013.png"
DEFAULT_MODEL_PATH = "outputs/checkpoint-27738"
PROCESSOR_NAME = "microsoft/trocr-small-stage1"


def _select_device() -> torch.device:
    """Return the CUDA device when one is visible, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_rgb(img_path) -> Image.Image:
    """Open *img_path* (str or Path) and return an RGB copy, closing the file."""
    # convert() returns a new image, so the source file can be closed here.
    with Image.open(img_path) as raw:
        return raw.convert("RGB")


def main(
    img_path=DEFAULT_IMG_PATH,
    model_path: str = DEFAULT_MODEL_PATH,
) -> str:
    """Decode *img_path* with the checkpoint at *model_path*.

    Args:
        img_path: Image file to recognise (str or ``pathlib.Path``).
        model_path: Directory (or identifier) accepted by
            ``VisionEncoderDecoderModel.from_pretrained``.

    Returns:
        The first decoded sequence, with special tokens skipped.
    """
    device = _select_device()
    img = _load_rgb(img_path)

    s1 = time.perf_counter()
    print("Loading model")
    processor = TrOCRProcessor.from_pretrained(PROCESSOR_NAME)
    model = VisionEncoderDecoderModel.from_pretrained(model_path).to(device)
    print("Finished loading model.")
    s2 = time.perf_counter()

    pixel_values = processor(img, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    print(pixel_values.shape)
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    if device.type == "cuda":
        # CUDA kernels run asynchronously; wait so the timing covers all GPU work.
        torch.cuda.synchronize()
    s3 = time.perf_counter()

    print(generated_text)
    print(f"loading_model: {s2 - s1}s")
    print(f"infer: {s3 - s2}s")
    return generated_text


def _parse_args() -> argparse.Namespace:
    """Parse optional CLI overrides; the defaults reproduce the original run."""
    parser = argparse.ArgumentParser(description="Single-image OCR inference.")
    parser.add_argument("--img-path", type=Path, default=Path(DEFAULT_IMG_PATH))
    parser.add_argument("--model-path", default=DEFAULT_MODEL_PATH)
    return parser.parse_args()


if __name__ == "__main__":
    args = _parse_args()
    main(args.img_path, args.model_path)