Skip to content

Commit

Permalink
Add transformation code for data preprocessing. Ref #4 #7
Browse files Browse the repository at this point in the history
  • Loading branch information
thss15fyt committed May 20, 2019
1 parent 4d5568e commit a102e65
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 3 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ For each data file `XXX.off` in ModelNet, we reorganize it to the format require
* The "face" part contains the center position, vertices' positions and normal vector of each face.
* The "neighbor_index" part contains the indices of neighbors of each face.

If you wish to create and use your own dataset, simplify your models and organize the `.off` files similar to the ModelNet dataset.
Then use the code in `data/preprocess.py` to transform them into the required `.npz` format.
Notice that the parameter `max_faces` in config files should be the maximum number of faces among all of your simplified mesh models.

##### Train Model

To train and evaluate MeshNet for classification and retrieval:
Expand Down
1 change: 1 addition & 0 deletions config/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cuda_devices: '0'
dataset:
data_root: 'ModelNet40_MeshNet/'
augment_data: false
max_faces: 1024

# model
load_model: 'MeshNet_best_9192.pkl'
Expand Down
1 change: 1 addition & 0 deletions config/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cuda_devices: '0' # multi-gpu training is available
dataset:
data_root: 'ModelNet40_MeshNet/'
augment_data: true
max_faces: 1024

# result
ckpt_root: 'ckpt_root/'
Expand Down
7 changes: 4 additions & 3 deletions data/ModelNet40.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ModelNet40(data.Dataset):
def __init__(self, cfg, part='train'):
self.root = cfg['data_root']
self.augment_data = cfg['augment_data']
self.max_faces = cfg['max_faces']
self.part = part

self.data = []
Expand All @@ -42,12 +43,12 @@ def __getitem__(self, i):
jittered_data = np.clip(sigma * np.random.randn(*face[:, :12].shape), -1 * clip, clip)
face = np.concatenate((face[:, :12] + jittered_data, face[:, 12:]), 1)

# fill for n < 1024
# fill for n < max_faces with randomly picked faces
num_point = len(face)
if num_point < 1024:
if num_point < self.max_faces:
fill_face = []
fill_neighbor_index = []
for i in range(1024 - num_point):
for i in range(self.max_faces - num_point):
index = np.random.randint(0, num_point)
fill_face.append(face[index])
fill_neighbor_index.append(neighbor_index[index])
Expand Down
92 changes: 92 additions & 0 deletions data/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import glob as glob
import numpy as np
import os
import pymesh


def find_neighbor(faces, faces_contain_this_vertex, vf1, vf2, except_face):
    """Return the index of the face that shares edge (vf1, vf2) with `except_face`.

    Args:
        faces: per-face vertex indices (unused here; kept for a
            backward-compatible signature).
        faces_contain_this_vertex: maps a vertex index to the set of indices
            of the faces containing that vertex.
        vf1, vf2: vertex indices of the shared edge.
        except_face: index of the face whose neighbor is being looked up.

    Returns:
        The neighboring face's index, or `except_face` itself when the edge
        lies on a boundary (no other face shares it).
    """
    # Faces containing both edge endpoints are exactly the faces on this edge.
    # (The original also built `face = faces[i].tolist()` and removed vf1/vf2
    # from it, but that result was never used — dead code, dropped.)
    for i in faces_contain_this_vertex[vf1] & faces_contain_this_vertex[vf2]:
        if i != except_face:
            return i

    return except_face


if __name__ == '__main__':
    # Convert simplified ModelNet40 .off meshes into the .npz format used by
    # MeshNet: per-face centers, corner coordinates, normals and the three
    # edge-neighbor face indices.

    root = 'ModelNet40_simplification'
    new_root = 'ModelNet40_MeshNet'

    if not os.path.exists(new_root):
        os.mkdir(new_root)

    # NOTE: `category` was `type` upstream — renamed to avoid shadowing the builtin.
    for category in os.listdir(root):
        for phrase in ['train', 'test']:
            phrase_path = os.path.join(root, category, phrase)

            # Mirror the <category>/<phrase> layout under the output root.
            # (Originally the *source* paths were tested and `os.mkdir(phrase)`
            # created a stray dir in the CWD.)
            new_category_path = os.path.join(new_root, category)
            new_phrase_path = os.path.join(new_category_path, phrase)
            if not os.path.exists(new_category_path):
                os.mkdir(new_category_path)
            if not os.path.exists(new_phrase_path):
                os.mkdir(new_phrase_path)

            # The module is imported as `glob`, so the function must be
            # qualified (calling the module itself raised TypeError).
            files = glob.glob(os.path.join(phrase_path, '*.off'))
            for file in files:
                # load mesh
                mesh = pymesh.load_mesh(file)

                # clean up
                mesh, _ = pymesh.remove_isolated_vertices(mesh)
                mesh, _ = pymesh.remove_duplicated_vertices(mesh)

                # get elements
                vertices = mesh.vertices.copy()
                faces = mesh.faces.copy()

                # move to center of the bounding box
                center = (np.max(vertices, 0) + np.min(vertices, 0)) / 2
                vertices -= center

                # scale so the farthest vertex lies on the unit sphere
                max_len = np.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2)
                vertices /= np.sqrt(max_len)

                # get per-face normal vectors via pymesh attributes
                mesh = pymesh.form_mesh(vertices, faces)
                mesh.add_attribute('face_normal')
                face_normal = mesh.get_face_attribute('face_normal')

                # vertex -> set of indices of faces containing that vertex
                # (original `for i in len(vertices)` raised TypeError)
                faces_contain_this_vertex = [set() for _ in range(len(vertices))]
                centers = []
                corners = []
                for face_index, f in enumerate(faces):
                    v1, v2, v3 = f
                    x1, y1, z1 = vertices[v1]
                    x2, y2, z2 = vertices[v2]
                    x3, y3, z3 = vertices[v3]
                    centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3])
                    corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3])
                    # Store the face *index*: find_neighbor iterates these sets
                    # and indexes `faces` with the members. (The original added
                    # the face array `f` itself, which is unhashable.)
                    faces_contain_this_vertex[v1].add(face_index)
                    faces_contain_this_vertex[v2].add(face_index)
                    faces_contain_this_vertex[v3].add(face_index)

                # three edge-neighbors per face (a face is its own neighbor on
                # boundary edges — see find_neighbor)
                neighbors = []
                for i in range(len(faces)):
                    v1, v2, v3 = faces[i]
                    n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i)
                    n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i)
                    n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i)
                    neighbors.append([n1, n2, n3])

                # per-face feature rows: [center(3), corners(9), normal(3)]
                centers = np.array(centers)
                corners = np.array(corners)
                faces = np.concatenate([centers, corners, face_normal], axis=1)
                neighbors = np.array(neighbors)

                _, filename = os.path.split(file)
                # os.path.join fixes the missing separator after new_root in
                # the original concatenation.
                np.savez(os.path.join(new_root, category, phrase, filename[:-4] + '.npz'),
                         faces=faces, neighbors=neighbors)

                print(file)

0 comments on commit a102e65

Please sign in to comment.