Skip to content

Commit

Permalink
support training with raw .obj file
Browse files Browse the repository at this point in the history
  • Loading branch information
thss15fyt committed Dec 21, 2021
1 parent 2793763 commit 70f9115
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 5 deletions.
82 changes: 78 additions & 4 deletions data/ModelNet40.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import os
import torch
import torch.utils.data as data
import pymeshlab
from data.preprocess import find_neighbor

type_to_index_map = {
'night_stand': 0, 'range_hood': 1, 'plant': 2, 'chair': 3, 'tent': 4,
Expand Down Expand Up @@ -33,14 +35,19 @@ def __init__(self, cfg, part='train'):
type_index = type_to_index_map[type]
type_root = os.path.join(os.path.join(self.root, type), part)
for filename in os.listdir(type_root):
if filename.endswith('.npz'):
if filename.endswith('.npz') or filename.endswith('.obj'):
self.data.append((os.path.join(type_root, filename), type_index))

def __getitem__(self, i):
path, type = self.data[i]
data = np.load(path)
face = data['faces']
neighbor_index = data['neighbors']
if path.endswith('.npz'):
data = np.load(path)
face = data['faces']
neighbor_index = data['neighbors']
else:
face, neighbor_index = process_mesh(path, self.max_faces)
if face is None:
return self.__getitem__(0)

# data augmentation
if self.augment_data and self.part == 'train':
Expand Down Expand Up @@ -74,3 +81,70 @@ def __getitem__(self, i):

def __len__(self):
return len(self.data)


def process_mesh(path, max_faces):
ms = pymeshlab.MeshSet()
ms.clear()

# load mesh
ms.load_new_mesh(path)
mesh = ms.current_mesh()

# # clean up
# mesh, _ = pymesh.remove_isolated_vertices(mesh)
# mesh, _ = pymesh.remove_duplicated_vertices(mesh)

# get elements
vertices = mesh.vertex_matrix()
faces = mesh.face_matrix()

if faces.shape[0] != max_faces: # only occur once in train set of Manifold40
print("Model with more than {} faces ({}): {}".format(max_faces, faces.shape[0], path))
return None, None

# move to center
center = (np.max(vertices, 0) + np.min(vertices, 0)) / 2
vertices -= center

# normalize
max_len = np.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2)
vertices /= np.sqrt(max_len)

# get normal vector
ms.clear()
mesh = pymeshlab.Mesh(vertices, faces)
ms.add_mesh(mesh)
face_normal = ms.current_mesh().face_normal_matrix()

# get neighbors
faces_contain_this_vertex = []
for i in range(len(vertices)):
faces_contain_this_vertex.append(set([]))
centers = []
corners = []
for i in range(len(faces)):
[v1, v2, v3] = faces[i]
x1, y1, z1 = vertices[v1]
x2, y2, z2 = vertices[v2]
x3, y3, z3 = vertices[v3]
centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3])
corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3])
faces_contain_this_vertex[v1].add(i)
faces_contain_this_vertex[v2].add(i)
faces_contain_this_vertex[v3].add(i)

neighbors = []
for i in range(len(faces)):
[v1, v2, v3] = faces[i]
n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i)
n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i)
n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i)
neighbors.append([n1, n2, n3])

centers = np.array(centers)
corners = np.array(corners)
faces = np.concatenate([centers, corners, face_normal], axis=1)
neighbors = np.array(neighbors)

return faces, neighbors
1 change: 0 additions & 1 deletion data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def find_neighbor(faces, faces_contain_this_vertex, vf1, vf2, except_face):

return except_face


if __name__ == '__main__':
root = Path('dataset/Manifold40')
new_root = Path('dataset/ModelNet40_processed')
Expand Down

0 comments on commit 70f9115

Please sign in to comment.