# utils.py
import torch
import torch.nn as nn

from smplx.lbs import transform_mat


def extract_cam_param_xml(xml_path='', dtype=torch.float32):
    """Parse camera calibration parameters from an OpenCV-style XML file.

    Returns the focal lengths, principal point, rotation, translation,
    camera center in world coordinates and the first two distortion
    coefficients.
    """
    import xml.etree.ElementTree as ET
    tree = ET.parse(xml_path)

    # CameraMatrix holds the 3x4 extrinsic matrix [R | t] in row-major order,
    # Intrinsics the 3x3 matrix K, and Distortion the distortion coefficients.
    extrinsics_mat = [float(s) for s in tree.find('./CameraMatrix/data').text.split()]
    intrinsics_mat = [float(s) for s in tree.find('./Intrinsics/data').text.split()]
    distortion_vec = [float(s) for s in tree.find('./Distortion/data').text.split()]

    focal_length_x = intrinsics_mat[0]
    focal_length_y = intrinsics_mat[4]
    center = torch.tensor([[intrinsics_mat[2], intrinsics_mat[5]]], dtype=dtype)

    rotation = torch.tensor([[extrinsics_mat[0], extrinsics_mat[1], extrinsics_mat[2]],
                             [extrinsics_mat[4], extrinsics_mat[5], extrinsics_mat[6]],
                             [extrinsics_mat[8], extrinsics_mat[9], extrinsics_mat[10]]], dtype=dtype)

    translation = torch.tensor([[extrinsics_mat[3], extrinsics_mat[7], extrinsics_mat[11]]], dtype=dtype)

    # t = -Rc --> c = -R^T t
    cam_center = [-extrinsics_mat[0] * extrinsics_mat[3] - extrinsics_mat[4] * extrinsics_mat[7] - extrinsics_mat[8] * extrinsics_mat[11],
                  -extrinsics_mat[1] * extrinsics_mat[3] - extrinsics_mat[5] * extrinsics_mat[7] - extrinsics_mat[9] * extrinsics_mat[11],
                  -extrinsics_mat[2] * extrinsics_mat[3] - extrinsics_mat[6] * extrinsics_mat[7] - extrinsics_mat[10] * extrinsics_mat[11]]
    cam_center = torch.tensor([cam_center], dtype=dtype)

    k1 = torch.tensor([distortion_vec[0]], dtype=dtype)
    k2 = torch.tensor([distortion_vec[1]], dtype=dtype)

    return focal_length_x, focal_length_y, center, rotation, translation, cam_center, k1, k2
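
# For reference, a minimal sketch of the calibration XML layout that
# extract_cam_param_xml expects, inferred from the parsing code above
# (node names only; actual files may contain additional OpenCV storage fields):
#
#   <opencv_storage>
#     <CameraMatrix><data> r00 r01 r02 tx  r10 r11 r12 ty  r20 r21 r22 tz </data></CameraMatrix>
#     <Intrinsics><data> fx 0 cx  0 fy cy  0 0 1 </data></Intrinsics>
#     <Distortion><data> k1 k2 ... </data></Distortion>
#   </opencv_storage>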


class CalibratedCamera(nn.Module):
    """Fixed, fully calibrated perspective camera.

    Loads intrinsics and extrinsics from the XML file at ``calib_path``
    and projects batches of 3D points to 2D image coordinates.
    """

    def __init__(self, calib_path='', rotation=None, translation=None,
                 focal_length_x=None, focal_length_y=None,
                 batch_size=1,
                 center=None, dtype=torch.float32, **kwargs):
        super(CalibratedCamera, self).__init__()
        self.batch_size = batch_size
        self.dtype = dtype
        self.calib_path = calib_path

        # Make a buffer so that PyTorch does not complain when creating
        # the camera matrix
        self.register_buffer('zero',
                             torch.zeros([batch_size], dtype=dtype))

        import os.path as osp
        if not osp.exists(calib_path):
            raise FileNotFoundError("Couldn't find {}.".format(calib_path))
        else:
            focal_length_x, focal_length_y, center, rotation, translation, cam_center, _, _ \
                = extract_cam_param_xml(xml_path=calib_path, dtype=dtype)

        # Broadcast scalar focal lengths to one value per batch element
        if focal_length_x is None or type(focal_length_x) == float:
            focal_length_x = torch.full(
                [batch_size],
                focal_length_x,
                dtype=dtype)

        if focal_length_y is None or type(focal_length_y) == float:
            focal_length_y = torch.full(
                [batch_size],
                focal_length_y,
                dtype=dtype)

        self.register_buffer('focal_length_x', focal_length_x)
        self.register_buffer('focal_length_y', focal_length_y)

        if center is None:
            center = torch.zeros([batch_size, 2], dtype=dtype)
        self.register_buffer('center', center)

        # The calibrated extrinsics are fixed, so rotation and translation
        # are registered as non-trainable parameters
        rotation = rotation.unsqueeze(dim=0).repeat(batch_size, 1, 1)
        rotation = nn.Parameter(rotation, requires_grad=False)
        self.register_parameter('rotation', rotation)

        if translation is None:
            translation = torch.zeros([batch_size, 3], dtype=dtype)
        translation = translation.view(3, -1).repeat(batch_size, 1, 1).squeeze(dim=-1)
        translation = nn.Parameter(translation, requires_grad=False)
        self.register_parameter('translation', translation)

        cam_center = nn.Parameter(cam_center, requires_grad=False)
        self.register_parameter('cam_center', cam_center)

    def forward(self, points):
        device = points.device

        with torch.no_grad():
            # 2x2 matrix holding the focal lengths; the principal point
            # offset is added separately below
            camera_mat = torch.zeros([self.batch_size, 2, 2],
                                     dtype=self.dtype, device=points.device)
            camera_mat[:, 0, 0] = self.focal_length_x
            camera_mat[:, 1, 1] = self.focal_length_y

        # 4x4 rigid world-to-camera transform built from R and t
        camera_transform = transform_mat(self.rotation,
                                         self.translation.unsqueeze(dim=-1))
        homog_coord = torch.ones(list(points.shape)[:-1] + [1],
                                 dtype=points.dtype,
                                 device=device)
        # Convert the points to homogeneous coordinates
        points_h = torch.cat([points, homog_coord], dim=-1)

        projected_points = torch.einsum('bki,bji->bjk',
                                        [camera_transform, points_h])

        # Perspective division by the depth, then scaling by the focal
        # lengths and shifting by the principal point
        img_points = torch.div(projected_points[:, :, :2],
                               projected_points[:, :, 2].unsqueeze(dim=-1))
        img_points = torch.einsum('bki,bji->bjk', [camera_mat, img_points]) \
            + self.center.unsqueeze(dim=1)
        return img_points
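

if __name__ == '__main__':
    # Minimal usage sketch, assuming a calibration XML at the hypothetical
    # path 'calib/cam0.xml'; replace it with a real file before running.
    camera = CalibratedCamera(calib_path='calib/cam0.xml', batch_size=1)
    # A batch of 100 random 3D points in world coordinates
    points = torch.rand([1, 100, 3], dtype=torch.float32)
    # Projected pixel coordinates, shape (batch_size, num_points, 2)
    img_points = camera(points)
    print(img_points.shape)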