diff --git a/.gitignore b/.gitignore index 660734b..2c95495 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ .idea/ -*.pkl \ No newline at end of file +ckpt_root/ +dataset/ +*.pkl +*/__pycache__/ \ No newline at end of file diff --git a/README.md b/README.md index 6054800..126e44c 100644 --- a/README.md +++ b/README.md @@ -1,54 +1,49 @@ ## MeshNet: Mesh Neural Network for 3D Shape Representation Created by Yutong Feng, Yifan Feng, Haoxuan You, Xibin Zhao, Yue Gao from Tsinghua University. -![pipeline](doc/pipeline.png) +![pipeline](doc/pipeline.PNG) ### Introduction -This work will appear in AAAI 2019. We proposed a novel framework (MeshNet) for 3D shape representation, which could learn on mesh data directly and achieve satisfying performance compared with traditional methods based on mesh and representative methods based on other types of data. You can also check out [paper](http://gaoyue.org/paper/MeshNet.pdf) for a deeper introduction. +This work was published in AAAI 2019. We proposed a novel framework (MeshNet) for 3D shape representation, which could learn on mesh data directly and achieve satisfying performance compared with traditional methods based on mesh and representative methods based on other types of data. You can also check out [paper](https://ojs.aaai.org/index.php/AAAI/article/view/4840/4713) for a deeper introduction. Mesh is an important and powerful type of data for 3D shapes. Due to the complexity and irregularity of mesh data, there is little effort on using mesh data for 3D shape representation in recent years. We propose a mesh neural network, named MeshNet, to learn 3D shape representation directly from mesh data. Face-unit and feature splitting are introduced to solve the complexity and irregularity problem. We have applied MeshNet in the applications of 3D shape classification and retrieval. 
Experimental results and comparisons with the state-of-the-art methods demonstrate that MeshNet can achieve satisfying 3D shape classification and retrieval performance, which indicates the effectiveness of the proposed method on 3D shape representation. In this repository, we release the code and data for train a Mesh Neural Network for classification and retrieval tasks on ModelNet40 dataset. -### Citation +### Update +**[2021/12]** We have released an updated version in which the proposed MeshNet achieves 92.75% classification accuracy on ModelNet40. The results are based on a better simplified version of ModelNet40, named "Manifold40", with watertight meshes and 500 faces per model. We also provide a more stable training script to reproduce this performance. See the Usage section for details. -if you find our work useful in your research, please consider citing: +### Usage +#### Installation +You can install the required packages as follows. This code has been tested with Python 3.8 and CUDA 11.1. ``` -@article{feng2018meshnet, - title={MeshNet: Mesh Neural Network for 3D Shape Representation}, - author={Feng, Yutong and Feng, Yifan and You, Haoxuan and Zhao, Xibin and Gao, Yue}, - journal={AAAI 2019}, - year={2018} -} +pip install -r requirements.txt ``` -### Installation - -Install [PyTorch 0.4.0](https://pytorch.org). You also need to install yaml. The code has been tested with Python 3.6, PyTorch 0.4.0 and CUDA 9.0 on Ubuntu 16.04. - -### Usage - ##### Data Preparation - -Firstly, you should download the [reorganized ModelNet40 dataset](https://drive.google.com/open?id=1o9pyskkKMxuomI5BWuLjCG2nSv5iePZz). Then, configure the "data_root" in `config/train_config.yaml` and `config/test_config.yaml` with your path to the downloaded dataset: - +MeshNet requires the pre-processed ModelNet40 with simplified and re-organized mesh data. 
To quickly start training, we recommend using our [pre-processed ModelNet40 dataset](https://cloud.tsinghua.edu.cn/f/77436a9afd294a52b492/?dl=1), and configure the "data_root" in `config/train_config.yaml` and `config/test_config.yaml` with your path to the downloaded dataset. By default, run ``` -# config/train_config.yaml and config/test_config.yaml -dataset: - data_root: [your_path_to_dataset] +wget --content-disposition https://cloud.tsinghua.edu.cn/f/77436a9afd294a52b492/?dl=1 +mkdir dataset +unzip -d dataset/ ModelNet40_processed.zip +rm ModelNet40_processed.zip ``` -For each data file `XXX.off` in ModelNet, we reorganize it to the format required by MeshNet and store it into `XXX.npz`. The reorganized file includes two parts of data: +The details of our pre-processing are as follows: The original dataset is from [ModelNet](http://modelnet.cs.princeton.edu/). First, we simplify each mesh model to no more than `max_faces` faces. We now recommend using the [Manifold40](https://cloud.tsinghua.edu.cn/f/2a292c598af94265a0b8/?dl=1) version with watertight meshes and `max_faces=500`. Then we reorganize the dataset to the format required by MeshNet and store it into `XXX.npz`. The reorganized file includes two parts of data: +- The "faces" part contains the center position, vertices' positions and normal vector of each face. +- The "neighbors" part contains the indices of neighbors of each face. -* The "face" part contains the center position, vertices' positions and normal vector of each face. -* The "neighbor_index" part contains the indices of neighbors of each face. +If you wish to create and use your own dataset, simplify your models into `.obj` format and use the code in `data/preprocess.py` to transform them into the required `.npz` format. Note that the parameter `max_faces` in config files should be the maximum number of faces among all of your simplified mesh models. 
-If you wish to create and use your own dataset, simplify your models and organize the `.off` files similar to the ModelNet dataset. -Then use the code in `data/preprocess.py` to transform them into the required `.npz` format. -Notice that the parameter `max_faces` in config files should be maximum number of faces among all of your simplified mesh models. +##### Evaluation +The pretrained MeshNet model weights are stored in [pretrained model](https://cloud.tsinghua.edu.cn/f/33bfdc6f103340daa86a/?dl=1). You can download it and configure the "load_model" in `config/test_config.yaml` with your path to the weight file. Then run the test script. +``` +wget --content-disposition https://cloud.tsinghua.edu.cn/f/33bfdc6f103340daa86a/?dl=1 +python test.py +``` -##### Train Model +##### Training To train and evaluate MeshNet for classification and retrieval: @@ -58,21 +53,23 @@ python train.py You can modify the configuration in the `config/train_config.yaml` for your own training, including the CUDA devices to use, the flag of data augmentation and the hyper-parameters of MeshNet. -##### Test Model - -The pretrained MeshNet model weights are stored in [pretrained model](https://drive.google.com/open?id=1l8Ij9BODxcD1goePBskPkBcgKW76Ewcs). You can download it and configure the "load_model" in `config/test_config.yaml` with your path to the weight file. 
-``` -# config/test_config.yaml -load_model: [your_path_to_weight_file] -``` +### Citation -To evaluate the model for classification and retrieval: +if you find our work useful in your research, please consider citing: -```bash -python test.py +``` +@inproceedings{feng2019meshnet, + title={Meshnet: Mesh neural network for 3d shape representation}, + author={Feng, Yutong and Feng, Yifan and You, Haoxuan and Zhao, Xibin and Gao, Yue}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={33}, + number={01}, + pages={8279--8286}, + year={2019} +} ``` ### Licence -Our code is released under MIT License (see LICENSE file for details). \ No newline at end of file +Our code is released under MIT License (see LICENSE file for details). diff --git a/config/config.py b/config/config.py index adf4e11..7d49f0a 100644 --- a/config/config.py +++ b/config/config.py @@ -14,7 +14,7 @@ def _check_dir(dir, make_dir=True): def get_train_config(config_file='config/train_config.yaml'): with open(config_file, 'r') as f: - cfg = yaml.load(f) + cfg = yaml.load(f, Loader=yaml.loader.SafeLoader) _check_dir(cfg['dataset']['data_root'], make_dir=False) _check_dir(cfg['ckpt_root']) @@ -24,7 +24,7 @@ def get_train_config(config_file='config/train_config.yaml'): def get_test_config(config_file='config/test_config.yaml'): with open(config_file, 'r') as f: - cfg = yaml.load(f) + cfg = yaml.load(f, Loader=yaml.loader.SafeLoader) _check_dir(cfg['dataset']['data_root'], make_dir=False) diff --git a/config/test_config.yaml b/config/test_config.yaml index dc109f7..16fa24e 100644 --- a/config/test_config.yaml +++ b/config/test_config.yaml @@ -3,12 +3,12 @@ cuda_devices: '0' # dataset dataset: - data_root: 'ModelNet40_MeshNet/' + data_root: 'dataset/ModelNet40_processed' augment_data: false - max_faces: 1024 + max_faces: 500 # model -load_model: 'MeshNet_best_9192.pkl' +load_model: 'MeshNet_ModelNet40_250e_bs128_lr6e-4.pkl' # MeshNet MeshNet: @@ -17,3 +17,10 @@ MeshNet: 
sigma: 0.2 mesh_convolution: aggregation_method: 'Concat' # Concat/Max/Average + mask_ratio: 0.95 + dropout: 0.5 + num_classes: 40 + +# test config +batch_size: 128 +retrieval_on: true diff --git a/config/train_config.yaml b/config/train_config.yaml index 13df564..9a5cb69 100644 --- a/config/train_config.yaml +++ b/config/train_config.yaml @@ -1,14 +1,16 @@ # CUDA -cuda_devices: '0' # multi-gpu training is available +cuda_devices: '0,1' # multi-gpu training is available # dataset dataset: - data_root: 'ModelNet40_MeshNet/' + data_root: 'dataset/ModelNet40_processed' + max_faces: 500 augment_data: true - max_faces: 1024 + jitter_sigma: 0.01 + jitter_clip: 0.05 # result -ckpt_root: 'ckpt_root/' +ckpt_root: 'ckpt_root' # MeshNet MeshNet: @@ -17,12 +19,20 @@ MeshNet: sigma: 0.2 mesh_convolution: aggregation_method: 'Concat' # Concat/Max/Average + mask_ratio: 0.95 + dropout: 0.5 + num_classes: 40 # train -lr: 0.01 +seed: 0 +lr: 0.0006 momentum: 0.9 weight_decay: 0.0005 -batch_size: 64 -max_epoch: 150 -milestones: [30, 60] +batch_size: 128 +max_epoch: 250 +optimizer: 'adamw' # sgd/adamw +scheduler: 'cos' # step/cos +milestones: [30, 60, 90] gamma: 0.1 +retrieval_on: true # enable evaluating retrieval performance during training +save_steps: 10 diff --git a/data/ModelNet40.py b/data/ModelNet40.py index 9049a1a..97123d4 100644 --- a/data/ModelNet40.py +++ b/data/ModelNet40.py @@ -19,12 +19,17 @@ class ModelNet40(data.Dataset): def __init__(self, cfg, part='train'): self.root = cfg['data_root'] - self.augment_data = cfg['augment_data'] self.max_faces = cfg['max_faces'] self.part = part + self.augment_data = cfg['augment_data'] + if self.augment_data: + self.jitter_sigma = cfg['jitter_sigma'] + self.jitter_clip = cfg['jitter_clip'] self.data = [] for type in os.listdir(self.root): + if type not in type_to_index_map.keys(): + continue type_index = type_to_index_map[type] type_root = os.path.join(os.path.join(self.root, type), part) for filename in os.listdir(type_root): @@ 
-34,14 +39,14 @@ def __init__(self, cfg, part='train'): def __getitem__(self, i): path, type = self.data[i] data = np.load(path) - face = data['face'] - neighbor_index = data['neighbor_index'] + face = data['faces'] + neighbor_index = data['neighbors'] # data augmentation if self.augment_data and self.part == 'train': - sigma, clip = 0.01, 0.05 - jittered_data = np.clip(sigma * np.random.randn(*face[:, :12].shape), -1 * clip, clip) - face = np.concatenate((face[:, :12] + jittered_data, face[:, 12:]), 1) + # jitter + jittered_data = np.clip(self.jitter_sigma * np.random.randn(*face[:, :3].shape), -1 * self.jitter_clip, self.jitter_clip) + face = np.concatenate((face[:, :3] + jittered_data, face[:, 3:]), 1) # fill for n < max_faces with randomly picked faces num_point = len(face) diff --git a/data/preprocess.py b/data/preprocess.py index f3523e4..3041599 100644 --- a/data/preprocess.py +++ b/data/preprocess.py @@ -1,7 +1,7 @@ -import glob as glob +import pymeshlab import numpy as np -import os -import pymesh +from pathlib import Path +from rich.progress import track def find_neighbor(faces, faces_contain_this_vertex, vf1, vf2, except_face): @@ -16,77 +16,77 @@ def find_neighbor(faces, faces_contain_this_vertex, vf1, vf2, except_face): if __name__ == '__main__': - - root = 'ModelNet40_simplification' - new_root = 'ModelNet40_MeshNet' - - for type in os.listdir(root): - for phrase in ['train', 'test']: - type_path = os.path.join(root, type) - phrase_path = os.path.join(type_path, phrase) - if not os.path.exists(type_path): - os.mkdir(os.path.join(new_root, type)) - if not os.path.exists(phrase_path): - os.mkdir(phrase) - - files = glob.glob(os.path.join(phrase_path, '*.off')) - for file in files: - # load mesh - mesh = pymesh.load_mesh(file) - - # clean up - mesh, _ = pymesh.remove_isolated_vertices(mesh) - mesh, _ = pymesh.remove_duplicated_vertices(mesh) - - # get elements - vertices = mesh.vertices.copy() - faces = mesh.faces.copy() - - # move to center - center = 
(np.max(vertices, 0) + np.min(vertices, 0)) / 2 - vertices -= center - - # normalize - max_len = np.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2) - vertices /= np.sqrt(max_len) - - # get normal vector - mesh = pymesh.form_mesh(vertices, faces) - mesh.add_attribute('face_normal') - face_normal = mesh.get_face_attribute('face_normal') - - # get neighbors - faces_contain_this_vertex = [] - for i in range(len(vertices)): - faces_contain_this_vertex.append(set([])) - centers = [] - corners = [] - for i in range(len(faces)): - [v1, v2, v3] = faces[i] - x1, y1, z1 = vertices[v1] - x2, y2, z2 = vertices[v2] - x3, y3, z3 = vertices[v3] - centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3]) - corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3]) - faces_contain_this_vertex[v1].add(i) - faces_contain_this_vertex[v2].add(i) - faces_contain_this_vertex[v3].add(i) - - neighbors = [] - for i in range(len(faces)): - [v1, v2, v3] = faces[i] - n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i) - n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i) - n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i) - neighbors.append([n1, n2, n3]) - - centers = np.array(centers) - corners = np.array(corners) - faces = np.concatenate([centers, corners, face_normal], axis=1) - neighbors = np.array(neighbors) - - _, filename = os.path.split(file) - np.savez(new_root + type + '/' + phrase + '/' + filename[:-4] + '.npz', - faces=faces, neighbors=neighbors) - - print(file) + root = Path('dataset/Manifold40') + new_root = Path('dataset/ModelNet40_processed') + max_faces = 500 + shape_list = sorted(list(root.glob('*/*/*.obj'))) + ms = pymeshlab.MeshSet() + + for shape_dir in track(shape_list): + out_dir = new_root / shape_dir.relative_to(root).with_suffix('.npz') + # if out_dir.exists(): + # continue + out_dir.parent.mkdir(parents=True, exist_ok=True) + + ms.clear() + # load mesh + ms.load_new_mesh(str(shape_dir)) + mesh = 
ms.current_mesh() + + # # clean up + # mesh, _ = pymesh.remove_isolated_vertices(mesh) + # mesh, _ = pymesh.remove_duplicated_vertices(mesh) + + # get elements + vertices = mesh.vertex_matrix() + faces = mesh.face_matrix() + + if faces.shape[0] != max_faces: + print("Model with more than {} faces ({}): {}".format(max_faces, faces.shape[0], out_dir)) + continue + + # move to center + center = (np.max(vertices, 0) + np.min(vertices, 0)) / 2 + vertices -= center + + # normalize + max_len = np.max(vertices[:, 0]**2 + vertices[:, 1]**2 + vertices[:, 2]**2) + vertices /= np.sqrt(max_len) + + # get normal vector + ms.clear() + mesh = pymeshlab.Mesh(vertices, faces) + ms.add_mesh(mesh) + face_normal = ms.current_mesh().face_normal_matrix() + + # get neighbors + faces_contain_this_vertex = [] + for i in range(len(vertices)): + faces_contain_this_vertex.append(set([])) + centers = [] + corners = [] + for i in range(len(faces)): + [v1, v2, v3] = faces[i] + x1, y1, z1 = vertices[v1] + x2, y2, z2 = vertices[v2] + x3, y3, z3 = vertices[v3] + centers.append([(x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3, (z1 + z2 + z3) / 3]) + corners.append([x1, y1, z1, x2, y2, z2, x3, y3, z3]) + faces_contain_this_vertex[v1].add(i) + faces_contain_this_vertex[v2].add(i) + faces_contain_this_vertex[v3].add(i) + + neighbors = [] + for i in range(len(faces)): + [v1, v2, v3] = faces[i] + n1 = find_neighbor(faces, faces_contain_this_vertex, v1, v2, i) + n2 = find_neighbor(faces, faces_contain_this_vertex, v2, v3, i) + n3 = find_neighbor(faces, faces_contain_this_vertex, v3, v1, i) + neighbors.append([n1, n2, n3]) + + centers = np.array(centers) + corners = np.array(corners) + faces = np.concatenate([centers, corners, face_normal], axis=1) + neighbors = np.array(neighbors) + + np.savez(str(out_dir), faces=faces, neighbors=neighbors) diff --git a/doc/pipeline.PNG b/doc/pipeline.PNG new file mode 100644 index 0000000..10d7d31 Binary files /dev/null and b/doc/pipeline.PNG differ diff --git a/doc/pipeline.png 
b/doc/pipeline.png deleted file mode 100644 index 22cd349..0000000 Binary files a/doc/pipeline.png and /dev/null differ diff --git a/models/MeshNet.py b/models/MeshNet.py index 82a26a2..1f28b54 100644 --- a/models/MeshNet.py +++ b/models/MeshNet.py @@ -23,14 +23,15 @@ def __init__(self, cfg, require_fea=False): nn.BatchNorm1d(1024), nn.ReLU(), ) + self.mask_ratio = cfg['mask_ratio'] self.classifier = nn.Sequential( nn.Linear(1024, 512), nn.ReLU(), - nn.Dropout(p=0.5), + nn.Dropout(p=cfg['dropout']), nn.Linear(512, 256), nn.ReLU(), - nn.Dropout(p=0.5), - nn.Linear(256, 40) + nn.Dropout(p=cfg['dropout']), + nn.Linear(256, cfg['num_classes']) ) def forward(self, centers, corners, normals, neighbor_index): @@ -41,7 +42,9 @@ def forward(self, centers, corners, normals, neighbor_index): spatial_fea2, structural_fea2 = self.mesh_conv2(spatial_fea1, structural_fea1, neighbor_index) spatial_fea3 = self.fusion_mlp(torch.cat([spatial_fea2, structural_fea2], 1)) - fea = self.concat_mlp(torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1)) + fea = self.concat_mlp(torch.cat([spatial_fea1, spatial_fea2, spatial_fea3], 1)) # b, c, n + if self.training: + fea = fea[:, :, torch.randperm(fea.size(2))[:int(fea.size(2) * (1 - self.mask_ratio))]] fea = torch.max(fea, dim=2)[0] fea = fea.reshape(fea.size(0), -1) fea = self.classifier[:-1](fea) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1bc70ec --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torch==1.8.0 +PyYAML==6.0 +pymeshlab==2021.10 +rich==10.16.1 +scipy diff --git a/test.py b/test.py index b48a961..cd02879 100644 --- a/test.py +++ b/test.py @@ -7,7 +7,7 @@ from config import get_test_config from data import ModelNet40 from models import MeshNet -from utils import append_feature, calculate_map +from utils.retrival import append_feature, calculate_map cfg = get_test_config() @@ -15,7 +15,7 @@ data_set = ModelNet40(cfg=cfg['dataset'], part='test') -data_loader = 
data.DataLoader(data_set, batch_size=1, num_workers=4, shuffle=True, pin_memory=False) +data_loader = data.DataLoader(data_set, batch_size=cfg['batch_size'], num_workers=4, shuffle=True, pin_memory=False) def test_model(model): @@ -23,24 +23,26 @@ def test_model(model): correct_num = 0 ft_all, lbl_all = None, None - for i, (centers, corners, normals, neighbor_index, targets) in enumerate(data_loader): - centers = Variable(torch.cuda.FloatTensor(centers.cuda())) - corners = Variable(torch.cuda.FloatTensor(corners.cuda())) - normals = Variable(torch.cuda.FloatTensor(normals.cuda())) - neighbor_index = Variable(torch.cuda.LongTensor(neighbor_index.cuda())) - targets = Variable(torch.cuda.LongTensor(targets.cuda())) + with torch.no_grad(): + for i, (centers, corners, normals, neighbor_index, targets) in enumerate(data_loader): + centers = centers.cuda() + corners = corners.cuda() + normals = normals.cuda() + neighbor_index = neighbor_index.cuda() + targets = targets.cuda() - outputs, feas = model(centers, corners, normals, neighbor_index) - _, preds = torch.max(outputs, 1) + outputs, feas = model(centers, corners, normals, neighbor_index) + _, preds = torch.max(outputs, 1) - if preds[0] == targets[0]: - correct_num += 1 + correct_num += (preds == targets).float().sum() - ft_all = append_feature(ft_all, feas.detach()) - lbl_all = append_feature(lbl_all, targets.detach(), flaten=True) + if cfg['retrieval_on']: + ft_all = append_feature(ft_all, feas.detach().cpu()) + lbl_all = append_feature(lbl_all, targets.detach().cpu(), flaten=True) print('Accuracy: {:.4f}'.format(float(correct_num) / len(data_set))) - print('mAP: {:.4f}'.format(calculate_map(ft_all, lbl_all))) + if cfg['retrieval_on']: + print('mAP: {:.4f}'.format(calculate_map(ft_all, lbl_all))) if __name__ == '__main__': diff --git a/train.py b/train.py index 243294d..0af21e0 100644 --- a/train.py +++ b/train.py @@ -1,20 +1,33 @@ import copy import os +import random import torch from torch.autograd import Variable 
import torch.nn as nn import torch.optim as optim import torch.utils.data as data +import torch.backends.cudnn as cudnn +import math +import numpy as np from config import get_train_config from data import ModelNet40 from models import MeshNet -from utils import append_feature, calculate_map +from utils.retrival import append_feature, calculate_map cfg = get_train_config() os.environ['CUDA_VISIBLE_DEVICES'] = cfg['cuda_devices'] +# seed +seed = cfg['seed'] +random.seed(seed) +os.environ['PYTHONHASHSEED'] = str(seed) +np.random.seed(seed) +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +torch.cuda.manual_seed_all(seed) +# dataset data_set = { x: ModelNet40(cfg=cfg['dataset'], part=x) for x in ['train', 'test'] } @@ -36,10 +49,10 @@ def train_model(model, criterion, optimizer, scheduler, cfg): print('Epoch: {} / {}'.format(epoch, cfg['max_epoch'])) print('-' * 60) + # adjust_learning_rate(cfg, epoch, optimizer) for phrase in ['train', 'test']: if phrase == 'train': - scheduler.step() model.train() else: model.eval() @@ -49,14 +62,11 @@ def train_model(model, criterion, optimizer, scheduler, cfg): ft_all, lbl_all = None, None for i, (centers, corners, normals, neighbor_index, targets) in enumerate(data_loader[phrase]): - - optimizer.zero_grad() - - centers = Variable(torch.cuda.FloatTensor(centers.cuda())) - corners = Variable(torch.cuda.FloatTensor(corners.cuda())) - normals = Variable(torch.cuda.FloatTensor(normals.cuda())) - neighbor_index = Variable(torch.cuda.LongTensor(neighbor_index.cuda())) - targets = Variable(torch.cuda.LongTensor(targets.cuda())) + centers = centers.cuda() + corners = corners.cuda() + normals = normals.cuda() + neighbor_index = neighbor_index.cuda() + targets = targets.cuda() with torch.set_grad_enabled(phrase == 'train'): outputs, feas = model(centers, corners, normals, neighbor_index) @@ -64,12 +74,13 @@ def train_model(model, criterion, optimizer, scheduler, cfg): loss = criterion(outputs, targets) if phrase == 'train': + 
optimizer.zero_grad() loss.backward() optimizer.step() - if phrase == 'test': - ft_all = append_feature(ft_all, feas.detach()) - lbl_all = append_feature(lbl_all, targets.detach(), flaten=True) + if phrase == 'test' and cfg['retrieval_on']: + ft_all = append_feature(ft_all, feas.detach().cpu()) + lbl_all = append_feature(lbl_all, targets.detach().cpu(), flaten=True) running_loss += loss.item() * centers.size(0) running_corrects += torch.sum(preds == targets.data) @@ -79,31 +90,55 @@ def train_model(model, criterion, optimizer, scheduler, cfg): if phrase == 'train': print('{} Loss: {:.4f} Acc: {:.4f}'.format(phrase, epoch_loss, epoch_acc)) + scheduler.step() if phrase == 'test': - epoch_map = calculate_map(ft_all, lbl_all) if epoch_acc > best_acc: best_acc = epoch_acc best_model_wts = copy.deepcopy(model.state_dict()) - if epoch_map > best_map: - best_map = epoch_map - if epoch % 10 == 0: - torch.save(copy.deepcopy(model.state_dict()), 'ckpt_root/{}.pkl'.format(epoch)) - - print('{} Loss: {:.4f} Acc: {:.4f} mAP: {:.4f}'.format(phrase, epoch_loss, epoch_acc, epoch_map)) + print_info = '{} Loss: {:.4f} Acc: {:.4f} (best {:.4f})'.format(phrase, epoch_loss, epoch_acc, best_acc) + + if cfg['retrieval_on']: + epoch_map = calculate_map(ft_all, lbl_all) + if epoch_map > best_map: + best_map = epoch_map + print_info += ' mAP: {:.4f}'.format(epoch_map) + + if epoch % cfg['save_steps'] == 0: + torch.save(copy.deepcopy(model.state_dict()), os.path.join(cfg['ckpt_root'], '{}.pkl'.format(epoch))) + + print(print_info) + + print('Best val acc: {:.4f}'.format(best_acc)) + print('Config: {}'.format(cfg)) return best_model_wts if __name__ == '__main__': + # prepare model model = MeshNet(cfg=cfg['MeshNet'], require_fea=True) model.cuda() model = nn.DataParallel(model) + # criterion criterion = nn.CrossEntropyLoss() - optimizer = optim.SGD(model.parameters(), lr=cfg['lr'], momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) - scheduler = 
optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones'], gamma=cfg['gamma']) + # optimizer + if cfg['optimizer'] == 'sgd': + optimizer = optim.SGD(model.parameters(), lr=cfg['lr'], momentum=cfg['momentum'], weight_decay=cfg['weight_decay']) + else: + optimizer = optim.AdamW(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay']) + + # scheduler + if cfg['scheduler'] == 'step': + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg['milestones']) + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['max_epoch']) + + # start training + if not os.path.exists(cfg['ckpt_root']): + os.mkdir(cfg['ckpt_root']) best_model_wts = train_model(model, criterion, optimizer, scheduler, cfg) - torch.save(best_model_wts, os.path.join(cfg['ckpt'], 'MeshNet_best.pkl')) + torch.save(best_model_wts, os.path.join(cfg['ckpt_root'], 'MeshNet_best.pkl')) diff --git a/utils/retrival.py b/utils/retrival.py index 802f4d4..3c4ea47 100644 --- a/utils/retrival.py +++ b/utils/retrival.py @@ -1,9 +1,7 @@ import os import numpy as np -import matplotlib.pyplot as plt - -# all 2468 shapes -top_k = 1000 +import scipy +import scipy.spatial def append_feature(raw, data, flaten=False): @@ -17,109 +15,54 @@ def append_feature(raw, data, flaten=False): return raw -def Eu_dis_mat_fast(X): - aa = np.sum(np.multiply(X, X), 1) - ab = X*X.T - D = aa+aa.T - 2*ab - D[D<0] = 0 - D = np.sqrt(D) - D = np.maximum(D, D.T) - return D - - def calculate_map(fts, lbls, dis_mat=None): - if dis_mat is None: - dis_mat = Eu_dis_mat_fast(np.mat(fts)) - num = len(lbls) - mAP = 0 - for i in range(num): - scores = dis_mat[:, i] - targets = (lbls == lbls[i]).astype(np.uint8) - sortind = np.argsort(scores, 0)[:top_k] - truth = targets[sortind] - sum = 0 - precision = [] - for j in range(top_k): - if truth[j]: - sum+=1 - precision.append(sum*1.0/(j + 1)) - if len(precision) == 0: - ap = 0 - else: - for ii in range(len(precision)): - precision[ii] = 
max(precision[ii:]) - ap = np.array(precision).mean() - mAP += ap - # print(f'{i+1}/{num}\tap:{ap:.3f}\t') - mAP = mAP/num - return mAP - + return map_score(fts, fts, lbls, lbls) -def cal_pr(cfg, des_mat, lbls, save=True, draw=False): - num = len(lbls) - precisions = [] - recalls = [] - ans = [] - for i in range(num): - scores = des_mat[:, i] - targets = (lbls == lbls[i]).astype(np.uint8) - sortind = np.argsort(scores, 0)[:top_k] - truth = targets[sortind] - tmp = 0 - sum = truth[:top_k].sum() - precision = [] - recall = [] - for j in range(top_k): - if truth[j]: - tmp+=1 - # precision.append(sum/(j + 1)) - recall.append(tmp*1.0/sum) - precision.append(tmp*1.0/(j+1)) - precisions.append(precision) - for j in range(len(precision)): - precision[j] = max(precision[j:]) - recalls.append(recall) - tmp = [] - for ii in range(11): - min_des = 100 - val = 0 - for j in range(top_k): - if abs(recall[j] - ii * 0.1) < min_des: - min_des = abs(recall[j] - ii * 0.1) - val = precision[j] - tmp.append(val) - print('%d/%d'%(i+1, num)) - ans.append(tmp) - ans = np.array(ans).mean(0) - if save: - save_dir = os.path.join(cfg.result_sub_folder, 'pr.csv') - np.savetxt(save_dir, np.array(ans), fmt='%.3f', delimiter=',') - if draw: - plt.plot(ans) - plt.show() +def acc_score(y_true, y_pred, average="micro"): + if isinstance(y_true, list): + y_true = np.array(y_true) + if isinstance(y_pred, list): + y_pred = np.array(y_pred) + if average == "micro": + # overall + return np.mean(y_true == y_pred) + elif average == "macro": + # average of each class + cls_acc = [] + for cls_idx in np.unique(y_true): + cls_acc.append(np.mean(y_pred[y_true==cls_idx]==cls_idx)) + return np.mean(np.array(cls_acc)) + else: + raise NotImplementedError -def test(): - scores = [0.23, 0.76, 0.01, 0.91, 0.13, 0.45, 0.12, 0.03, 0.38, 0.11, 0.03, 0.09, 0.65, 0.07, 0.12, 0.24, 0.1, 0.23, 0.46, 0.08] - gt_label = [0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1] - scores = np.array(scores) - targets = 
np.array(gt_label).astype(np.uint8) - sortind = np.argsort(scores, 0)[::-1] - truth = targets[sortind] - sum = 0 - precision = [] - for j in range(20): - if truth[j]: - sum += 1 - precision.append(sum / (j + 1)) - if len(precision) == 0: - ap = 0 +def cdist(fts_a, fts_b, metric): + if metric == 'inner': + return np.matmul(fts_a, fts_b.T) else: - for i in range(len(precision)): - precision[i] = max(precision[i:]) - ap = np.array(precision).mean() - print(ap) + return scipy.spatial.distance.cdist(fts_a, fts_b, metric) + +def map_score(fts_a, fts_b, lbl_a, lbl_b, metric='cosine'): + dist = cdist(fts_a, fts_b, metric) + res = map_from_dist(dist, lbl_a, lbl_b) + return res -if __name__ == '__main__': - test() +def map_from_dist(dist, lbl_a, lbl_b): + n_a, n_b = dist.shape + s_idx = dist.argsort() + + res = [] + for i in range(n_a): + order = s_idx[i] + p = 0.0 + r = 0.0 + for j in range(n_b): + if lbl_a[i] == lbl_b[order[j]]: + r += 1 + p += (r / (j + 1)) + if r > 0: + res.append(p/r) + else: + res.append(0) + return np.mean(res)