Merge pull request #3 from Zerohertz/main
Fix: Type of `depths` & Generalization according to parameters & some issues
zhangjx123 authored Jun 29, 2023
2 parents 9fbd6dc + 45da880 commit 7cec76a
Showing 5 changed files with 36 additions and 33 deletions.
19 changes: 10 additions & 9 deletions datasets/ocr_dataset.py
@@ -20,10 +20,10 @@


 class CocoDetection(torchvision.datasets.CocoDetection):
-    def __init__(self, img_folder, ann_file, transforms, return_masks, dataset_name):
+    def __init__(self, img_folder, ann_file, transforms, return_masks, dataset_name, max_length):
         super(CocoDetection, self).__init__(img_folder, ann_file)
         self._transforms = transforms
-        self.prepare = ConvertCocoPolysToMask(return_masks, dataset_name)
+        self.prepare = ConvertCocoPolysToMask(return_masks, dataset_name, max_length)

     def __getitem__(self, idx):
         img, target = super(CocoDetection, self).__getitem__(idx)
@@ -37,9 +37,10 @@ def __getitem__(self, idx):


 class ConvertCocoPolysToMask(object):
-    def __init__(self, return_masks=False, dataset_name=''):
+    def __init__(self, return_masks=False, dataset_name='', max_length=25):
         self.return_masks = return_masks
         self.dataset_name = dataset_name
+        self.max_length = max_length

     def __call__(self, image, target):
         w, h = image.size
@@ -80,13 +81,13 @@ def __call__(self, image, target):
         target["orig_size"] = torch.as_tensor([int(h), int(w)])
         target["size"] = torch.as_tensor([int(h), int(w)])

-        recog = [obj['rec'][:25] for obj in anno]
-        recog = torch.tensor(recog, dtype=torch.long).reshape(-1, 25)
-        target["rec"] = recog
+        recog = [obj['rec'][:self.max_length] for obj in anno]
+        recog = torch.tensor(recog, dtype=torch.long).reshape(-1, self.max_length)
+        target["rec"] = recog[keep]

         bezier_pts = [obj['bezier_pts'] for obj in anno]
         bezier_pts = torch.tensor(bezier_pts, dtype=torch.float32).reshape(-1, 16)
-        target['bezier_pts'] = bezier_pts
+        target['bezier_pts'] = bezier_pts[keep]
         center_pts = torch.zeros(bezier_pts.shape[0], 2)
         for i in range(bezier_pts.shape[0]):
             tmp = bezier_pts[i]
@@ -141,7 +142,7 @@ def __call__(self, image, target):
             center_pts[i][0] = xc
             center_pts[i][1] = yc

-        target['center_pts'] = center_pts
+        target['center_pts'] = center_pts[keep]
         assert target['center_pts'].shape[0] == target['bezier_pts'].shape[0]
         return image, target

@@ -226,7 +227,7 @@ def build(image_set, args):
            args.max_size_test, args.min_size_test, args.crop_min_ratio, args.crop_max_ratio,
            args.crop_prob, args.rotate_max_angle, args.rotate_prob, args.brightness, args.contrast,
            args.saturation, args.hue, args.distortion_prob)
-        dataset = CocoDetection(img_folder, ann_file, transforms=transforms, return_masks=args.masks, dataset_name=dataset_name)
+        dataset = CocoDetection(img_folder, ann_file, transforms=transforms, return_masks=args.masks, dataset_name=dataset_name, max_length=args.max_length)
         datasets.append(dataset)

     if len(datasets) > 1:
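
The [keep] indexing above presumably applies the same validity mask that earlier code in __call__ (not shown in this hunk) already uses to filter the boxes; applying it to rec, bezier_pts, and center_pts as well keeps every per-instance tensor aligned row for row, which is what the assert checks. A minimal sketch of the idea, with an invented mask and invented shapes:

    import torch

    # Hypothetical example: 3 annotations, of which 2 survive the validity check.
    keep = torch.tensor([True, False, True])   # e.g. boxes with positive width and height
    rec = torch.randint(0, 96, (3, 25))        # per-instance recognition labels
    bezier_pts = torch.rand(3, 16)             # per-instance Bezier control points
    center_pts = torch.rand(3, 2)              # per-instance center points

    # Filtering every tensor with the same mask keeps them aligned row for row.
    rec, bezier_pts, center_pts = rec[keep], bezier_pts[keep], center_pts[keep]
    assert rec.shape[0] == bezier_pts.shape[0] == center_pts.shape[0] == int(keep.sum())
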
12 changes: 6 additions & 6 deletions engine_sptsv2.py
@@ -20,7 +20,7 @@
 def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                     data_loader: Iterable, optimizer: torch.optim.Optimizer,
                     device: torch.device, epoch: int, max_norm: float = 0,
-                    lr_scheduler: list = [0], print_freq: int = 10):
+                    lr_scheduler: list = [0], print_freq: int = 10, text_length: int = 25):
     model.train()
     criterion.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
@@ -35,7 +35,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
         if not all(input_label_seqs.tolist()):
             continue
         output_seqs = torch.cat([output_box_seqs.flatten(),output_label_seqs.flatten() ])
-        outputs_box, outputs_label = model(samples, input_box_seqs, input_label_seqs)
+        outputs_box, outputs_label = model(samples, input_box_seqs, input_label_seqs, text_length)
         outputs_box = outputs_box.reshape(-1, outputs_box.shape[-1])
         outputs_label = outputs_label.reshape(-1, outputs_label.shape[-1])
         outputs = torch.cat([outputs_box,outputs_label],0)
@@ -74,7 +74,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,


 @torch.no_grad()
-def evaluate(model, criterion, data_loader, device, output_dir, chars, start_index, visualize=False):
+def evaluate(model, criterion, data_loader, device, output_dir, chars, start_index, visualize=False, text_length=25):
     model.eval()
     criterion.eval()
     chars = list(chars)
@@ -93,7 +93,7 @@ def evaluate(model, criterion, data_loader, device, output_dir, chars, start_ind
         seq = torch.ones(len(targets), 1).to(samples.mask) * start_index
         torch.cuda.synchronize()
         t0 = time.time()
-        outputs = model(samples, seq,seq)
+        outputs = model(samples, seq,seq, text_length)
         torch.cuda.synchronize()
         t1 = time.time()
         cnt += 1
@@ -104,7 +104,7 @@ def evaluate(model, criterion, data_loader, device, output_dir, chars, start_ind
         outputs, values, rec_scores = outputs
         if visualize:
             samples_ = samples.to(torch.device('cpu')); outputs_ = outputs.cpu()
-            vis_images = vis_output_seqs(samples_, outputs_, rec_scores, False)
+            vis_images = vis_output_seqs(samples_, outputs_, rec_scores, False, True, text_length, chars)
             for vis_image, target, dataset_name in zip(vis_images, targets, dataset_names):
                 save_path = os.path.join(output_dir, 'vis', dataset_name, '{:06d}.jpg'.format(target['image_id'].item()))
                 os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -113,7 +113,7 @@ def evaluate(model, criterion, data_loader, device, output_dir, chars, start_ind
         outputs = outputs.cpu(); values = values.cpu(); rec_scores = rec_scores.cpu()
         for target, output, value, rec_score in zip(targets, outputs, values, rec_scores):
             image_id = target['image_id'].item()
-            output, split_index = extract_result_from_output_seqs(output, rec_score, return_index=True)
+            output, split_index = extract_result_from_output_seqs(output, rec_score, return_index=True, text_length=text_length, chars=chars)
             split_values = [value[split_index[i]:split_index[i+1]] for i in range(0, len(split_index)-1)]
             center_pts = output['center_pts']; rec_labels = output['rec']; rec_scores = output['key_rec_score']
             rec_labels = convert_rec_to_str(rec_labels, chars)
7 changes: 4 additions & 3 deletions main.py
@@ -93,7 +93,7 @@ def get_args_parser():
                         help="Size of the embeddings (dimension of the transformer)")
     parser.add_argument('--dropout', default=0.1, type=float,
                         help="Dropout applied in the transformer")
-    parser.add_argument('--depths', default=[6], type=int,
+    parser.add_argument('--depths', default=6, type=int,
                         help="swin transformer structure")
     parser.add_argument('--nheads', default=8, type=int,
                         help="Number of attention heads inside the transformer's attentions")
@@ -208,7 +208,8 @@ def main(args):
         evaluate(model, criterion,
                  data_loader_val, device,
                  args.output_dir, args.chars,
-                 args.start_index, args.visualize)
+                 args.start_index, args.visualize,
+                 args.max_length)
         return

     print("Start training")
@@ -225,7 +226,7 @@ def main(args):
             sampler_train.set_epoch(epoch)
         train_stats = train_one_epoch(
             model, criterion, data_loader_train, optimizer, device, epoch,
-            args.clip_max_norm, learning_rate_schedule, args.print_freq)
+            args.clip_max_norm, learning_rate_schedule, args.print_freq, args.max_length)
         lr_scheduler.step()
         if args.output_dir:
             checkpoint_paths = [output_dir / 'checkpoint.pth']
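
The `--depths` change also fixes a type mismatch: argparse applies `type` only to values parsed from the command line (and to string defaults), not to other defaults, so `default=[6]` with `type=int` produced a list when the flag was omitted and an int when it was given. A standalone illustration, not the repository's actual parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--depths', default=6, type=int)   # default now matches the declared type

    print(parser.parse_args([]).depths)                    # 6 (int) when the flag is omitted
    print(parser.parse_args(['--depths', '12']).depths)    # 12 (int) when the flag is given
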
15 changes: 8 additions & 7 deletions predict.py
@@ -68,6 +68,7 @@ def get_args_parser():
     parser.add_argument('--img_path', default="", type=str, help='path for the image to be detected')
     return parser

+@torch.no_grad()
 def main(args):
     args = process_args(args)
     device = torch.device(args.device)
@@ -106,18 +107,18 @@ def main(args):
     model.eval()

     # get predictions
-    output = model(image_new,seq,seq)
+    output = model(image_new,seq,seq, text_length=args.max_length)
     outputs, values, _ = output
-    N = (outputs[0].shape[0])//27
+    N = (outputs[0].shape[0])//(args.max_length+2)
     img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
     for i in range(N):
-        v = values[0][27*i:(27)*i+27].mean().item()
+        v = values[0][(args.max_length+2)*i:((args.max_length+2))*i+(args.max_length+2)].mean().item()
         if v > 0.922:
             text = ''
-            pts_x = outputs[0][27*i].item() * (float(w_ori) / 1000)
-            pts_y = outputs[0][27*i+1].item() * (float(h_ori) / 1000)
-            for c in outputs[0][27*i+2:27*i+27].tolist():
-                if 1000 < c <1096:
+            pts_x = outputs[0][(args.max_length+2)*i].item() * (float(w_ori) / 1000)
+            pts_y = outputs[0][(args.max_length+2)*i+1].item() * (float(h_ori) / 1000)
+            for c in outputs[0][(args.max_length+2)*i+2:(args.max_length+2)*i+(args.max_length+2)].tolist():
+                if 1000 < c < 1000 + len(args.chars) + 1:
                     text += args.chars[c-1000]
                 else:
                     break
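
The constants replaced in predict.py encode the decoding layout: each predicted instance occupies max_length + 2 tokens (an x bin, a y bin, then up to max_length recognition tokens), so with the default max_length of 25 the stride was 27. Deriving the stride from args.max_length keeps the slicing correct for other lengths. A small self-contained sketch of the stride arithmetic, with invented token values:

    max_length = 25
    step = max_length + 2            # one x token, one y token, then up to max_length characters

    # A flat output sequence holding N instances back to back.
    N = 3
    tokens = list(range(N * step))   # stand-in for the model's predicted token ids

    for i in range(N):
        inst = tokens[step * i : step * (i + 1)]
        x_tok, y_tok, rec_toks = inst[0], inst[1], inst[2:]
        assert len(rec_toks) == max_length
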
16 changes: 8 additions & 8 deletions util/visualize.py
@@ -9,8 +9,9 @@

 def extract_result_from_output_seqs(
     seqs, rec_score, key_pts='center_pts', key_label='rec', return_index=False,
-    end_index=1097, bins=1000, padding_bins=0, pad_rec=True
+    bins=1000, padding_bins=0, pad_rec=True, text_length=25, chars=[]
 ):
+    end_index = bins + len(chars) + 1
     target = {
         key_pts: [],
         key_label: [],
@@ -19,7 +20,7 @@ def extract_result_from_output_seqs(
     # pts_len = 16 if key_pts=='bezier_pts' else 4
     pts_len = 2
     category_start_index = bins + 2*padding_bins
-    rec_score = rec_score[:, category_start_index:category_start_index+95]
+    rec_score = rec_score[:, category_start_index:category_start_index+len(chars)]
     split_index = [0, ]; index = 0; rec_index = 0
     while(True):
         if index >= len(seqs) or seqs[index] == end_index:
@@ -36,7 +37,7 @@ def extract_result_from_output_seqs(
                 break
             index += 1
         else:
-            label = seqs[index:index+25]; index += 25
+            label = seqs[index:index+text_length]; index += text_length
             label = torch.clamp(label, min=category_start_index, max=end_index-1) - category_start_index
             if end_index - 1 in label:
                 if torch.min(torch.where(label==end_index-1)[0]) == 0:
@@ -45,8 +46,8 @@ def extract_result_from_output_seqs(
         split_index.append(index)
         target[key_pts].append(pts)
         target[key_label].append(label)
-        target['key_rec_score'].append(rec_score[rec_index:rec_index+25].softmax(-1))
-        rec_index = rec_index +25
+        target['key_rec_score'].append(rec_score[rec_index:rec_index+text_length].softmax(-1))
+        rec_index = rec_index +text_length
     if return_index:
         return target, split_index
     return target
@@ -128,13 +129,12 @@ def draw_text(image, text, pt):
     cv2.putText(image, text, (int(pt[0]), int(pt[1])), cv2.FONT_HERSHEY_COMPLEX, 1, (0, 255, 0), 2)
     return image

-def vis_output_seqs(samples, output_seqs, rec_scores, remove_padding=False, pad_rec=False):
-    chars = list(' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~')
+def vis_output_seqs(samples, output_seqs, rec_scores, remove_padding=False, pad_rec=False, text_length=25, chars=[]):
     tensors = samples.tensors
     targets = []
     # targets = [extract_result_from_output_seqs(ele, pad_rec=pad_rec) for ele in output_seqs]
     for ele, rec_score in zip(output_seqs, rec_scores):
-        targets.append(extract_result_from_output_seqs(ele, rec_score, pad_rec=pad_rec))
+        targets.append(extract_result_from_output_seqs(ele, rec_score, pad_rec=pad_rec, text_length=text_length , chars=chars))
     center_pts = [convert_pt_to_pixel(target['center_pts'], tensors.shape[2], tensors.shape[3]) for target in targets]
     rec_labels = [convert_rec_to_str(target['rec'], chars) for target in targets]
     vis_images = []
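
With these changes util/visualize.py derives its sequence layout from the charset instead of hard-coding it: the recognition-score slice spans len(chars) columns rather than 95, and the terminating id is bins + len(chars) + 1 rather than 1097, with the charset now forwarded from args.chars through engine_sptsv2.py. A tiny runnable sketch of the derivation, with an invented charset and the default padding_bins of 0 assumed:

    bins = 1000                          # ids [0, bins) encode quantized coordinates
    chars = list('0123456789abcdef')     # stand-in charset; the real one comes from args.chars
    padding_bins = 0

    category_start_index = bins + 2 * padding_bins   # first id used for characters
    end_index = bins + len(chars) + 1                # id that terminates the output sequence
    assert end_index - category_start_index == len(chars) + 1
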
