-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathvisualisation.py
103 lines (84 loc) · 3.37 KB
/
visualisation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
# import matplotlib
# matplotlib.use('MACOSX')
import matplotlib.pyplot as plt
from smplx import SMPL
from utils.renderer import Renderer
from utils.cam_utils import perspective_project_torch
from data.ssp3d_dataset import SSP3DDataset
import config
# Gendered SMPL body models (PyTorch), both loaded from the same model
# directory with batch_size=1; built in the order male, then female.
smpl_male, smpl_female = (
    SMPL(config.SMPL_MODEL_DIR, batch_size=1, gender=smpl_gender)
    for smpl_gender in ('male', 'female')
)

# Pyrender-based mesh renderer with img_res=512. The male model's faces are
# used for every render regardless of gender (assumes both SMPL models share
# one face topology — confirm if other body models are swapped in).
renderer = Renderer(img_res=512, faces=smpl_male.faces)

# SSP-3D dataset wrapper rooted at config.SSP_3D_PATH.
ssp3d_dataset = SSP3DDataset(config.SSP_3D_PATH)

# Indices of the three SSP-3D examples to visualise.
indices_to_plot = [11, 60, 199]
# SSP-3D gender label -> SMPL model. A lookup (instead of an if/elif with no
# else branch) makes an unexpected label fail loudly rather than raising a
# NameError or silently reusing the SMPL output from the previous example.
smpl_models = {'m': smpl_male, 'f': smpl_female}

for i in indices_to_plot:
    data = ssp3d_dataset[i]
    fname = data['fname']
    image = data['image']
    cropped_image = data['cropped_image']
    silhouette = data['silhouette']
    joints2D = data['joints2D']
    gender = data['gender']

    try:
        smpl_model = smpl_models[gender]
    except KeyError:
        raise ValueError(f"Unexpected gender label {gender!r} for {fname}") from None

    # Add a leading batch dimension of 1 to match the batch_size=1 SMPL models.
    body_shape = torch.from_numpy(data['shape'][None, :]).float()
    body_pose = torch.from_numpy(data['pose'][None, :]).float()
    cam_trans = torch.from_numpy(data['cam_trans'][None, :]).float()

    # Obtaining body vertex mesh from SMPL shape and pose. This is inference
    # only, so no autograd graph is built.
    with torch.no_grad():
        smpl_output = smpl_model(body_pose=body_pose[:, 3:],
                                 # First 3 axis-angle pose parameters are global body rotation
                                 global_orient=body_pose[:, :3],
                                 betas=body_shape)
        vertices = smpl_output.vertices
        # img_wh=512 mirrors the renderer's img_res=512; presumably the SSP-3D
        # images are 512x512 too, so the projection overlays them — confirm.
        projected_vertices = perspective_project_torch(vertices, cam_trans,
                                                       focal_length=config.FOCAL_LENGTH,
                                                       img_wh=512)

    # Drop the batch dimension and convert to NumPy for rendering/plotting.
    vertices = vertices.cpu().numpy()[0]
    projected_vertices = projected_vertices.cpu().numpy()[0]
    cam_trans = cam_trans.cpu().numpy()[0]

    # Rendering vertex mesh over the input image
    rend_img = renderer(vertices, cam_trans, img=image)

    # Visualise: a 2x3 grid of panels with axes hidden.
    fig, axes = plt.subplots(2, 3, figsize=(14, 8))
    fig.suptitle(fname)
    (ax_image, ax_crop, ax_joints), (ax_sil, ax_proj, ax_render) = axes
    for ax in axes.flat:
        ax.set_axis_off()

    ax_image.set_title('Image')
    ax_image.imshow(image)

    ax_crop.set_title('Cropped Image')
    ax_crop.imshow(cropped_image)

    ax_joints.set_title('2D Joints')
    ax_joints.imshow(image)
    # Draw all joints in one scatter call, then label each with its index.
    ax_joints.scatter(joints2D[:, 0], joints2D[:, 1], s=2, c='r')
    for j, (x, y) in enumerate(joints2D[:, :2]):
        ax_joints.text(x, y, s=str(j))

    ax_sil.set_title('Silhouette')
    ax_sil.imshow(image)
    ax_sil.imshow(silhouette, alpha=0.4)

    ax_proj.set_title('Projected Vertices')
    ax_proj.imshow(image)
    ax_proj.scatter(projected_vertices[:, 0], projected_vertices[:, 1], s=1)

    ax_render.set_title('Body Render')
    ax_render.imshow(rend_img)

    fig.subplots_adjust(top=0.93, bottom=0.0, right=1, left=0, hspace=0.13, wspace=0)
    plt.show()
    # Release the figure once its window is closed (or immediately on
    # non-interactive backends) so figures do not accumulate across examples.
    plt.close(fig)