forked from zorzi-s/PolyWorldPretrainedNetwork
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
90 lines (72 loc) · 2.54 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import numpy as np
from scipy.optimize import linear_sum_assignment
def scores_to_permutations(scores):
    """Convert a batch of score matrices into hard 0/1 assignment matrices.

    Every sample ``scores[b]`` is solved independently with the Hungarian
    algorithm (``scipy.optimize.linear_sum_assignment``), choosing the
    one-to-one row/column matching with the largest total score.

    Args:
        scores: 3-D torch tensor ``(batch, rows, cols)``. It is detached and
            copied to host memory before solving, so gradients do not flow.

    Returns:
        A CPU torch tensor with the same shape and dtype as ``scores``. It
        holds 1 at every matched ``(row, col)`` position and 0 everywhere else.
    """
    # Unpacking enforces a 3-D input; only the batch size is needed here.
    num_samples, _, _ = scores.shape
    scores_np = scores.detach().cpu().numpy()
    assignment = np.zeros_like(scores_np)
    for sample in range(num_samples):
        # linear_sum_assignment minimises total cost, so negate to maximise score.
        row_ids, col_ids = linear_sum_assignment(-scores_np[sample])
        assignment[sample, row_ids, col_ids] = 1
    return torch.tensor(assignment)
def _trace_cycles(successors):
    """Split a successor list into closed vertex cycles.

    ``successors[i]`` is the vertex that follows vertex ``i``. Cycles are
    returned in ascending order of their lowest vertex. Each cycle starts at
    that vertex and repeats it at the end, e.g. ``[0, 2, 1, 0]``. If
    ``successors`` is not a permutation, a walk stops as soon as it reaches an
    already-visited vertex instead of looping forever.
    """
    visited = [False] * len(successors)
    cycles = []
    for start in range(len(successors)):
        if visited[start]:
            continue
        visited[start] = True
        cycle = [start]
        vertex = successors[start]
        while True:
            cycle.append(vertex)
            # For a permutation this fires exactly when we are back at `start`.
            if visited[vertex]:
                break
            visited[vertex] = True
            vertex = successors[vertex]
        cycles.append(cycle)
    return cycles


def _format_polygon(coords, out):
    """Convert one polygon's vertex rows (a torch tensor) to the `out` format."""
    if out == 'torch':
        return coords
    if out == 'numpy':
        # detach/cpu so graphs that require grad or live on a GPU still convert.
        return coords.detach().cpu().numpy()
    # NOTE(review): presumably rescales from a 320-px network frame to 300-px
    # image coordinates — confirm against the dataset configuration.
    scaled = coords * 300 / 320
    if out == 'list':
        scaled[:, 0] = -scaled[:, 0]
        return torch.fliplr(scaled).tolist()
    # 'coco': swap the two columns, then flatten to [c1, c0, c1, c0, ...].
    return torch.fliplr(scaled).view(-1).tolist()


def permutations_to_polygons(perm, graph, out='torch'):
    """Group permutation-linked vertices into closed polygons.

    For each sample ``b``, row ``i`` of ``perm[b]`` selects (via argmax) the
    vertex that follows vertex ``i``. Vertices with a non-zero diagonal entry
    point to themselves and are discarded. The remaining successor links are
    traced into cycles. Each cycle becomes one polygon whose first vertex is
    repeated at the end (vertices ``[a, b, c, a]``). Polygons are ordered by
    their lowest-index vertex.

    Args:
        perm: tensor of shape ``(B, N, N)``, expected to hold hard permutation
            matrices such as those from ``scores_to_permutations``. The vertex
            filter indexes a CPU ``torch.arange`` with a mask derived from
            ``perm``, so ``perm`` presumably has to be on the CPU.
        graph: indexed per sample as ``graph[b][rows, :]``. Presumably a
            ``(B, N, 2)`` tensor of vertex coordinates; confirm with callers.
        out: format of each returned polygon:
            ``'torch'``: the selected rows of ``graph[b]`` as a tensor.
            ``'numpy'``: the same rows as a NumPy array (detached, on CPU).
            ``'list'``: rows scaled by 300/320, first column negated, then
                the two columns swapped, as nested Python lists.
            ``'coco'``: rows scaled by 300/320, columns swapped, flattened
                into one Python list.

    Returns:
        A list with one entry per sample. Each entry is a (possibly empty)
        list of polygons in the requested format.

    Raises:
        ValueError: if ``out`` is not one of the supported formats.
    """
    valid_formats = ('torch', 'numpy', 'list', 'coco')
    if out not in valid_formats:
        raise ValueError(
            "Indicate a valid output polygon format: got {!r}, expected one "
            "of {}".format(out, valid_formats))

    batch_size, _, num_vertices = perm.shape
    # True where a vertex is linked to some *other* vertex.
    linked = torch.logical_not(perm[:, range(num_vertices), range(num_vertices)])

    batch = []
    for b in range(batch_size):
        idx = torch.arange(num_vertices)[linked[b]]
        if idx.shape[0] == 0:
            # No vertex in this sample is linked to another one.
            batch.append([])
            continue
        b_graph = graph[b][idx, :]
        # For a permutation matrix, the rows/columns of linked vertices form a
        # permutation sub-matrix, so argmax gives each vertex's successor.
        b_perm = perm[b][idx, :][:, idx]
        successors = torch.argmax(b_perm, dim=1).tolist()
        batch.append([
            _format_polygon(b_graph[cycle, :], out)
            for cycle in _trace_cycles(successors)
        ])
    return batch