-
Notifications
You must be signed in to change notification settings - Fork 0
/
transition_up.py
46 lines (39 loc) · 1.58 KB
/
transition_up.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
PointTransformer Transition Up blocks
"""
import mindspore.nn as nn
import mindspore.ops as ops
from pointnet2_fp import PointNetFeaturePropagation
class _SwapAxes(nn.Cell):
    """Swap the last two axes of a rank-3 tensor: (B, X, Y) -> (B, Y, X)."""

    def __init__(self):
        super(_SwapAxes, self).__init__()
        self.transpose = ops.Transpose()

    def construct(self, x):
        return self.transpose(x, (0, 2, 1))


class TransitionUp(nn.Cell):
    """
    TransitionUp Block.

    Each input feature set is projected to ``dim_out`` channels with
    Dense -> BatchNorm -> ReLU (one BatchNorm per branch). The projected
    ``point1`` features are then passed through ``PointNetFeaturePropagation``
    together with both coordinate sets. The result is added to the projected
    ``point2`` features.

    Args:
        dim1: channel count of ``point1`` (input size of ``f1``).
        dim2: channel count of ``point2`` (input size of ``f2``).
        dim_out: output channel count of both projections.
    """

    def __init__(self, dim1, dim2, dim_out):
        super(TransitionUp, self).__init__()
        self.f1 = nn.Dense(dim1, dim_out)
        self.f2 = nn.Dense(dim2, dim_out)
        self.Swap = _SwapAxes()
        # MindSpore's BatchNorm momentum weights the *previous* running
        # statistic (default 0.9), the inverse of PyTorch's convention.
        # 0.9 here corresponds to PyTorch's default momentum=0.1; the former
        # value 0.1 made the running statistics follow almost only the last
        # batch.
        # Separate BatchNorms: the two branches carry different feature
        # distributions and must not share running statistics or gamma/beta.
        # `BN` keeps its name for branch 1 so existing `BN.*` checkpoint keys
        # still load. Checkpoints from the earlier shared-BN version have no
        # `BN2.*` entries, so BN2 starts from its default initialization.
        self.BN = nn.BatchNorm2d(dim_out, momentum=0.9, affine=True)
        self.BN2 = nn.BatchNorm2d(dim_out, momentum=0.9, affine=True)
        self.relu = nn.ReLU()
        # Create the primitives once instead of on every construct() call.
        self.expand_dims = ops.ExpandDims()
        self.squeeze = ops.Squeeze(-1)
        self.fp = PointNetFeaturePropagation(-1, [])

    def construct(self, xyz1, point1, xyz2, point2):
        """
        TransitionUp Block forward pass.

        Args:
            xyz1: coordinates paired with ``point1``. They are passed as the
                second argument of ``self.fp``; the expected layout depends on
                ``PointNetFeaturePropagation`` (not visible here, so confirm
                there).
            point1: features of shape (B, N1, dim1).
            xyz2: coordinates paired with ``point2``. They are passed as the
                first argument of ``self.fp``; the same layout caveat applies.
            point2: features of shape (B, N2, dim2).

        Returns:
            ``feats2 + self.fp(...)``. Here ``feats2`` is channel-first,
            (B, dim_out, N2), so the propagated ``point1`` features must
            broadcast to that shape.
        """
        # Dense acts on the last axis; BatchNorm2d needs (B, C, H, W), so go
        # channel-first and add a trailing singleton axis for the norm.
        feats1 = self.Swap(self.f1(point1))  # (B, dim_out, N1)
        feats1 = self.relu(self.squeeze(self.BN(self.expand_dims(feats1, -1))))
        feats2 = self.Swap(self.f2(point2))  # (B, dim_out, N2)
        feats2 = self.relu(self.squeeze(self.BN2(self.expand_dims(feats2, -1))))
        # Presumably interpolates feats1 from the xyz1 points onto the xyz2
        # points (the sum below requires an N2-sized result). TODO: confirm
        # against PointNetFeaturePropagation's argument order.
        mind = self.fp(xyz2, xyz1, None, feats1)
        return feats2 + mind