Add HEDNet
zhanggang001 committed Dec 28, 2023
1 parent 92419eb commit ed844ed
Showing 15 changed files with 810 additions and 285 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -7,7 +7,6 @@ data/
 venv/
 *.idea/
 *.so
-*.yaml
 *.sh
 *.pth
 *.pkl
302 changes: 35 additions & 267 deletions README.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pcdet/config.py
@@ -50,7 +50,7 @@ def cfg_from_list(cfg_list, config):

 def merge_new_config(config, new_config):
     if '_BASE_CONFIG_' in new_config:
-        with open(new_config['_BASE_CONFIG_'], 'r') as f:
+        with open(new_config['_BASE_CONFIG_'], 'r', encoding="utf-8") as f:
             try:
                 yaml_config = yaml.safe_load(f, Loader=yaml.FullLoader)
             except:
@@ -69,7 +69,7 @@ def merge_new_config(config, new_config):


 def cfg_from_yaml_file(cfg_file, config):
-    with open(cfg_file, 'r') as f:
+    with open(cfg_file, 'r', encoding="utf-8") as f:
         try:
             new_config = yaml.safe_load(f, Loader=yaml.FullLoader)
         except:
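Passing encoding="utf-8" explicitly makes YAML config loading independent of the platform's default locale (for example cp1252 on some Windows setups). A minimal usage sketch of the patched reader; the config path is illustrative, not a file confirmed by this commit:

# Sketch: load a model config through cfg_from_yaml_file (path is illustrative).
from pcdet.config import cfg, cfg_from_yaml_file

cfg_from_yaml_file('tools/cfgs/waymo_models/hednet.yaml', cfg)
print(cfg.MODEL.NAME)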
22 changes: 11 additions & 11 deletions pcdet/datasets/augmentor/database_sampler.py
@@ -69,17 +69,17 @@ def __getstate__(self):
     def __setstate__(self, d):
         self.__dict__.update(d)

-    def __del__(self):
-        if self.use_shared_memory:
-            self.logger.info('Deleting GT database from shared memory')
-            cur_rank, num_gpus = common_utils.get_dist_info()
-            sa_key = self.sampler_cfg.DB_DATA_PATH[0]
-            if cur_rank % num_gpus == 0 and os.path.exists(f"/dev/shm/{sa_key}"):
-                SharedArray.delete(f"shm://{sa_key}")
-
-            if num_gpus > 1:
-                dist.barrier()
-            self.logger.info('GT database has been removed from shared memory')
+    # def __del__(self):
+    #     if self.use_shared_memory:
+    #         self.logger.info('Deleting GT database from shared memory')
+    #         cur_rank, num_gpus = common_utils.get_dist_info()
+    #         sa_key = self.sampler_cfg.DB_DATA_PATH[0]
+    #         if cur_rank % num_gpus == 0 and os.path.exists(f"/dev/shm/{sa_key}"):
+    #             SharedArray.delete(f"shm://{sa_key}")
+
+    #         if num_gpus > 1:
+    #             dist.barrier()
+    #         self.logger.info('GT database has been removed from shared memory')

     def load_db_to_shared_memory(self):
         self.logger.info('Loading GT database to shared memory')
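With the __del__ hook commented out, the GT database loaded by load_db_to_shared_memory is no longer freed when the sampler is destroyed, so stale segments can linger in /dev/shm between runs. A hedged manual-cleanup sketch; the sa_key value is hypothetical and should come from sampler_cfg.DB_DATA_PATH[0]:

# Sketch: remove a GT database left behind in shared memory.
import os
import SharedArray

sa_key = 'waymo_gt_database_train.npy'  # hypothetical; use DB_DATA_PATH[0]
if os.path.exists(f"/dev/shm/{sa_key}"):
    SharedArray.delete(f"shm://{sa_key}")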
17 changes: 16 additions & 1 deletion pcdet/datasets/waymo/waymo_dataset.py
@@ -442,10 +442,25 @@ def waymo_eval(eval_det_annos, eval_gt_annos):
                 distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
             )
             ap_result_str = '\n'
-            for key in ap_dict:
+            overall_result = {}
+            for idx, key in enumerate(ap_dict):
+                level_metric = key.split('_')[5]  # '1/AP', '2/AP', '1/APH', '2/APH'
+                key_overall = "LEVEL_" + level_metric + '_Overall'
+                if key_overall in overall_result.keys():
+                    overall_result[key_overall]["value"] = overall_result[key_overall]["value"] + ap_dict[key][0]
+                    overall_result[key_overall]["count"] = overall_result[key_overall]["count"] + 1
+                else:
+                    overall_result[key_overall] = {}
+                    overall_result[key_overall]["value"] = ap_dict[key][0]
+                    overall_result[key_overall]["count"] = 1
+
                 ap_dict[key] = ap_dict[key][0]
                 ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])

+            for key in overall_result:
+                ap_dict[key] = overall_result[key]['value'] / overall_result[key]['count']
+                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])
+
             return ap_result_str, ap_dict

         eval_det_annos = copy.deepcopy(det_annos)
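The new block averages the per-class numbers into LEVEL_<N>/<METRIC>_Overall entries. Waymo metric keys follow the pattern OBJECT_TYPE_TYPE_<CLASS>_LEVEL_<N>/<METRIC>, so key.split('_')[5] yields strings like '1/APH'. A self-contained sketch of the same reduction on made-up values:

# Toy reproduction of the overall-metric averaging (values are invented).
ap_dict = {
    'OBJECT_TYPE_TYPE_VEHICLE_LEVEL_1/APH': [0.80],
    'OBJECT_TYPE_TYPE_PEDESTRIAN_LEVEL_1/APH': [0.70],
}
overall = {}
for key in ap_dict:
    key_overall = 'LEVEL_' + key.split('_')[5] + '_Overall'
    entry = overall.setdefault(key_overall, {'value': 0.0, 'count': 0})
    entry['value'] += ap_dict[key][0]
    entry['count'] += 1
for key, entry in overall.items():
    print(key, entry['value'] / entry['count'])  # LEVEL_1/APH_Overall 0.75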
2 changes: 1 addition & 1 deletion pcdet/datasets/waymo/waymo_eval.py
@@ -169,7 +169,7 @@ def run_eval_ops(
         )

     def eval_value_ops(self, sess, graph, metrics):
-        return {item[0]: sess.run([item[1][0]]) for item in metrics.items()}
+        return {item[0]: sess.run([item[1][0]]) for item in metrics.items() if 'SIGN' not in item[0] and 'APH' in item[0]}  # only show APH, filter out 'SIGN'

     def mask_by_distance(self, distance_thresh, boxes_3d, *args):
         mask = np.linalg.norm(boxes_3d[:, 0:2], axis=1) < distance_thresh + 0.5
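The added predicate drops the SIGN class and keeps only APH entries, so the evaluated dictionary shrinks to the metrics HEDNet reports. A toy sketch of the same filter (keys are illustrative):

# Which metric keys survive the new filter.
keys = ['OBJECT_TYPE_TYPE_VEHICLE_LEVEL_1/AP',
        'OBJECT_TYPE_TYPE_VEHICLE_LEVEL_1/APH',
        'OBJECT_TYPE_TYPE_SIGN_LEVEL_1/APH']
kept = [k for k in keys if 'SIGN' not in k and 'APH' in k]
print(kept)  # ['OBJECT_TYPE_TYPE_VEHICLE_LEVEL_1/APH']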
2 changes: 2 additions & 0 deletions pcdet/models/backbones_2d/__init__.py
@@ -1,7 +1,9 @@
 from .base_bev_backbone import BaseBEVBackbone, BaseBEVBackboneV1, BaseBEVResBackbone
+from .bev_backbone_ded import CascadeDEDBackbone

 __all__ = {
     'BaseBEVBackbone': BaseBEVBackbone,
     'BaseBEVBackboneV1': BaseBEVBackboneV1,
     'BaseBEVResBackbone': BaseBEVResBackbone,
+    'CascadeDEDBackbone': CascadeDEDBackbone,
 }
89 changes: 89 additions & 0 deletions pcdet/models/backbones_2d/bev_backbone_ded.py
@@ -0,0 +1,89 @@
import torch.nn as nn
from .base_bev_backbone import BasicBlock


class DEDBackbone(nn.Module):

    def __init__(self, model_cfg, input_channels):
        super().__init__()

        num_SBB = model_cfg.NUM_SBB
        down_strides = model_cfg.DOWN_STRIDES
        dim = model_cfg.FEATURE_DIM
        assert len(num_SBB) == len(down_strides)

        num_levels = len(down_strides)

        first_block = []
        if input_channels != dim:
            first_block.append(BasicBlock(input_channels, dim, down_strides[0], 1, True))
        first_block += [BasicBlock(dim, dim) for _ in range(num_SBB[0])]
        self.encoder = nn.ModuleList([nn.Sequential(*first_block)])

        for idx in range(1, num_levels):
            cur_layers = [BasicBlock(dim, dim, down_strides[idx], 1, True)]
            cur_layers.extend([BasicBlock(dim, dim) for _ in range(num_SBB[idx])])
            self.encoder.append(nn.Sequential(*cur_layers))

        self.decoder = nn.ModuleList()
        self.decoder_norm = nn.ModuleList()
        for idx in range(num_levels - 1, 0, -1):
            self.decoder.append(
                nn.Sequential(
                    nn.ConvTranspose2d(dim, dim, down_strides[idx], down_strides[idx], bias=False),
                    nn.BatchNorm2d(dim, eps=1e-3, momentum=0.01),
                    nn.ReLU()
                )
            )
            self.decoder_norm.append(nn.BatchNorm2d(dim, eps=1e-3, momentum=0.01))

        self.num_bev_features = dim
        self.init_weights()

    def init_weights(self):
        for _, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out', nonlinearity='relu')
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, data_dict):
        x = data_dict['spatial_features']
        x = self.encoder[0](x)

        feats = [x]
        for conv in self.encoder[1:]:
            x = conv(x)
            feats.append(x)

        for deconv, norm, up_x in zip(self.decoder, self.decoder_norm, feats[:-1][::-1]):
            x = norm(deconv(x) + up_x)

        data_dict['spatial_features_2d'] = x
        data_dict['spatial_features'] = x
        return data_dict


class CascadeDEDBackbone(nn.Module):

    def __init__(self, model_cfg, input_channels):
        super().__init__()

        num_layers = model_cfg.NUM_LAYERS

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            input_dim = input_channels if idx == 0 else model_cfg.FEATURE_DIM
            self.layers.append(DEDBackbone(model_cfg, input_dim))

        self.num_bev_features = model_cfg.FEATURE_DIM

    def forward(self, data_dict):
        for layer in self.layers:
            data_dict = layer(data_dict)
        data_dict['spatial_features_2d'] = data_dict['spatial_features']
        return data_dict
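A smoke-test sketch for the new 2D backbone. The config values are hypothetical (real ones come from the model YAML), and EasyDict stands in for the attribute-style config pcdet builds; BasicBlock's (in, out, stride, padding, downsample) signature is the one in base_bev_backbone:

# Sketch: run CascadeDEDBackbone on random BEV features (hypothetical dims).
import torch
from easydict import EasyDict
from pcdet.models.backbones_2d.bev_backbone_ded import CascadeDEDBackbone

cfg = EasyDict(NUM_LAYERS=2, NUM_SBB=[2, 1, 1], DOWN_STRIDES=[1, 2, 2], FEATURE_DIM=128)
model = CascadeDEDBackbone(cfg, input_channels=128)
out = model({'spatial_features': torch.randn(1, 128, 188, 188)})
print(out['spatial_features_2d'].shape)  # torch.Size([1, 128, 188, 188])

Each DEDBackbone downsamples twice (188 -> 94 -> 47) and the transposed convolutions restore the input resolution, so the cascade preserves the BEV map size.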
3 changes: 3 additions & 0 deletions pcdet/models/backbones_3d/__init__.py
@@ -6,6 +6,8 @@
 from .spconv_backbone_voxelnext2d import VoxelResBackBone8xVoxelNeXt2D
 from .spconv_unet import UNetV2
 from .dsvt import DSVT
+from .spconv_backbone_sed import HEDNet
+

 __all__ = {
     'VoxelBackBone8x': VoxelBackBone8x,
@@ -19,4 +21,5 @@
     'PillarBackBone8x': PillarBackBone8x,
     'PillarRes18BackBone8x': PillarRes18BackBone8x,
     'DSVT': DSVT,
+    'HEDNet': HEDNet,
 }
139 changes: 139 additions & 0 deletions pcdet/models/backbones_3d/spconv_backbone_sed.py
@@ -0,0 +1,139 @@

import torch.nn as nn
from functools import partial
from ...utils.spconv_utils import replace_feature, spconv
from .spconv_backbone import post_act_block, SparseBasicBlock


class SEDBlock(spconv.SparseModule):

    def __init__(self, dim, kernel_size, stride, num_SBB, norm_fn, indice_key):
        super(SEDBlock, self).__init__()

        first_block = post_act_block(
            dim, dim, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2,
            norm_fn=norm_fn, indice_key=f'spconv_{indice_key}', conv_type='spconv')

        block_list = [first_block if stride > 1 else nn.Identity()]
        for _ in range(num_SBB):
            block_list.append(
                SparseBasicBlock(dim, dim, norm_fn=norm_fn, indice_key=indice_key))

        self.blocks = spconv.SparseSequential(*block_list)

    def forward(self, x):
        return self.blocks(x)


class SEDLayer(spconv.SparseModule):

    def __init__(self, dim: int, down_kernel_size: list, down_stride: list, num_SBB: list, norm_fn, indice_key):
        super().__init__()

        assert down_stride[0] == 1  # hard code
        assert len(down_kernel_size) == len(down_stride) == len(num_SBB)

        self.encoder = nn.ModuleList()
        for idx in range(len(down_stride)):
            self.encoder.append(
                SEDBlock(dim, down_kernel_size[idx], down_stride[idx], num_SBB[idx], norm_fn, f"{indice_key}_{idx}"))

        downsample_times = len(down_stride[1:])
        self.decoder = nn.ModuleList()
        self.decoder_norm = nn.ModuleList()
        for idx, kernel_size in enumerate(down_kernel_size[1:]):
            self.decoder.append(
                post_act_block(
                    dim, dim, kernel_size, norm_fn=norm_fn, conv_type='inverseconv',
                    indice_key=f'spconv_{indice_key}_{downsample_times - idx}'))
            self.decoder_norm.append(norm_fn(dim))

    def forward(self, x):
        features = []
        for conv in self.encoder:
            x = conv(x)
            features.append(x)

        x = features[-1]
        for deconv, norm, up_x in zip(self.decoder, self.decoder_norm, features[:-1][::-1]):
            x = deconv(x)
            x = replace_feature(x, x.features + up_x.features)
            x = replace_feature(x, norm(x.features))
        return x


class HEDNet(nn.Module):

    def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
        super().__init__()

        self.sparse_shape = grid_size[::-1] + [1, 0, 0]
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)

        dim = model_cfg.FEATURE_DIM
        num_layers = model_cfg.NUM_LAYERS
        num_SBB = model_cfg.NUM_SBB
        down_kernel_size = model_cfg.DOWN_KERNEL_SIZE
        down_stride = model_cfg.DOWN_STRIDE

        # [1888, 1888, 41] -> [944, 944, 21]
        self.conv1 = spconv.SparseSequential(
            post_act_block(input_channels, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1', conv_type='subm'),
            SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='stem'),
            SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='stem'),
            post_act_block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv1', conv_type='spconv'),
        )

        # [944, 944, 21] -> [472, 472, 11]
        self.conv2 = spconv.SparseSequential(
            SEDLayer(32, down_kernel_size, down_stride, num_SBB, norm_fn=norm_fn, indice_key='sedlayer2'),
            post_act_block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'),
        )

        # [472, 472, 11] -> [236, 236, 11]
        self.conv3 = spconv.SparseSequential(
            SEDLayer(64, down_kernel_size, down_stride, num_SBB, norm_fn=norm_fn, indice_key='sedlayer3'),
            post_act_block(64, dim, 3, norm_fn=norm_fn, stride=(1, 2, 2), padding=1, indice_key='spconv3', conv_type='spconv'),
        )

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            conv = SEDLayer(dim, down_kernel_size, down_stride, num_SBB, norm_fn=norm_fn, indice_key=f'sedlayer{idx+4}')
            self.layers.append(conv)

        # [236, 236, 11] -> [236, 236, 5] --> [236, 236, 2]
        self.conv_out = spconv.SparseSequential(
            spconv.SparseConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=0, bias=False, indice_key='spconv4'),
            norm_fn(dim),
            nn.ReLU(),
            spconv.SparseConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=0, bias=False, indice_key='spconv5'),
            norm_fn(dim),
            nn.ReLU(),
        )

        self.num_point_features = dim

    def forward(self, batch_dict):
        voxel_features = batch_dict['voxel_features']
        voxel_coords = batch_dict['voxel_coords']
        batch_size = batch_dict['batch_size']

        x = spconv.SparseConvTensor(
            features=voxel_features,
            indices=voxel_coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size
        )

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        for conv in self.layers:
            x = conv(x)
        x = self.conv_out(x)

        batch_dict.update({
            'encoded_spconv_tensor': x,
            'encoded_spconv_tensor_stride': 8
        })
        return batch_dict
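A construction sketch for the new 3D backbone. All config numbers are hypothetical and simply mirror the keys read in __init__; spconv must be installed, and forward() additionally needs CUDA voxel data. Note grid_size is a numpy array, so the + [1, 0, 0] pads the z dimension elementwise:

# Sketch: build HEDNet with hypothetical config values (no forward pass).
import numpy as np
from easydict import EasyDict
from pcdet.models.backbones_3d.spconv_backbone_sed import HEDNet

cfg = EasyDict(FEATURE_DIM=128, NUM_LAYERS=2, NUM_SBB=[2, 1, 1],
               DOWN_KERNEL_SIZE=[3, 3, 3], DOWN_STRIDE=[1, 2, 2])
backbone = HEDNet(cfg, input_channels=5, grid_size=np.array([1888, 1888, 40]))
print(backbone.sparse_shape)  # [41, 1888, 1888] in (z, y, x) order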
2 changes: 1 addition & 1 deletion pcdet/models/detectors/detector3d_template.py
@@ -100,7 +100,7 @@ def build_backbone_2d(self, model_info_dict):

         backbone_2d_module = backbones_2d.__all__[self.model_cfg.BACKBONE_2D.NAME](
             model_cfg=self.model_cfg.BACKBONE_2D,
-            input_channels=model_info_dict.get('num_bev_features', None)
+            input_channels=model_info_dict.get('num_bev_features', model_info_dict['num_point_features'])
         )
         model_info_dict['module_list'].append(backbone_2d_module)
         model_info_dict['num_bev_features'] = backbone_2d_module.num_bev_features
2 changes: 1 addition & 1 deletion pcdet/models/model_utils/model_nms_utils.py
@@ -93,7 +93,7 @@ def class_specific_nms(box_scores, box_preds, box_labels, nms_config, score_thre
             curr_boxes_for_nms = cur_box_preds

         keep_idx, _ = getattr(iou3d_nms_utils, 'nms_gpu')(
-            curr_boxes_for_nms, curr_box_scores_nms,
+            curr_boxes_for_nms[:, :7], curr_box_scores_nms,
             thresh=nms_config.NMS_THRESH[k],
             pre_maxsize=nms_config.NMS_PRE_MAXSIZE[k],
             post_max_size=nms_config.NMS_POST_MAXSIZE[k]
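The slice matters when box predictions carry more than the 7 geometry parameters (x, y, z, dx, dy, dz, heading), for example extra velocity channels on Waymo, since the GPU NMS kernel accepts only 7-dim boxes. A hedged sketch of the corrected call; it requires a CUDA build of pcdet, and the box/score tensors are invented:

# Sketch: 9-dim boxes must be cut to the 7 geometry dims before GPU NMS.
import torch
from pcdet.ops.iou3d_nms import iou3d_nms_utils

boxes = torch.randn(100, 9).cuda()   # hypothetical predictions with extras
scores = torch.rand(100).cuda()
keep_idx, _ = iou3d_nms_utils.nms_gpu(boxes[:, :7], scores, thresh=0.7)
print(keep_idx.shape)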