diff --git a/vajra/core/exporter.py b/vajra/core/exporter.py index 28c1b15..1cb156c 100644 --- a/vajra/core/exporter.py +++ b/vajra/core/exporter.py @@ -18,7 +18,7 @@ from vajra.dataset.dataset import VajraDetDataset from vajra.dataset.build import build_dataloader from vajra.dataset.utils import check_det_dataset, check_cls_dataset, check_class_names, default_class_names -from vajra.nn.modules import VajraMerudandaBhag1, VajraMerudandaBhag4, AttentionBottleneckV2 +from vajra.nn.modules import VajraMerudandaBhag1, VajraMerudandaBhag4, VajraMerudandaBhag7, AttentionBottleneckV2 from vajra.nn.head import Detection from vajra.nn.vajra import DetectionModel, SegmentationModel, VajraWorld from vajra.utils import ( @@ -168,7 +168,7 @@ def __call__(self, model=None): module.dynamic = self.args.dynamic module.export = True module.format = self.args.format - elif isinstance(module, (VajraMerudandaBhag4, AttentionBottleneckV2)) and not is_tf_format: + elif isinstance(module, (VajraMerudandaBhag4, AttentionBottleneckV2, VajraMerudandaBhag7)) and not is_tf_format: module.forward = module.forward_split y = None diff --git a/vajra/nn/head.py b/vajra/nn/head.py index 0445811..3b3e468 100644 --- a/vajra/nn/head.py +++ b/vajra/nn/head.py @@ -63,7 +63,7 @@ def __init__(self, num_classes=80, in_channels=[]) -> None: for ch in in_channels ) self.branch_cls = nn.ModuleList( - nn.Sequential(nn.Sequential(DepthwiseConvBNAct(ch, ch, 1, 3), ConvBNAct(ch, c3, 1, 1)),#ConvBNAct(ch, c3, 1, 3), + nn.Sequential(nn.Sequential(DepthwiseConvBNAct(ch, ch, 1, 3), ConvBNAct(ch, c3, 1, 1)), nn.Sequential(DepthwiseConvBNAct(c3, c3, 1, 3), ConvBNAct(c3, c3, 1, 1)), nn.Conv2d(c3, self.num_classes, 1)) for ch in in_channels @@ -158,11 +158,11 @@ def __init__(self, num_classes=80, num_masks=32, num_protos=256, in_channels=[]) self.num_masks = num_masks self.num_protos = num_protos self.proto = ProtoMaskModule(in_channels[0], self.num_protos, self.num_masks) - self.detection = Detection.forward c4 = max(in_channels[0] // 4, self.num_masks) self.branch_seg = nn.ModuleList( nn.Sequential( ConvBNAct(ch, c4, kernel_size=3, stride=1), + ConvBNAct(c4, c4, kernel_size=3, stride=1), nn.Conv2d(c4, self.num_masks, 1) ) for ch in in_channels @@ -174,7 +174,7 @@ def forward(self, x): mask_coefficients = torch.cat([self.branch_seg[i](x[i]).view(batch_size, self.num_masks, -1) for i in range(self.num_det_layers)], 2) - x = self.detection(self, x) + x = Detection.forward(self, x) if self.training: return x, mask_coefficients, proto_masks return (torch.cat([x, mask_coefficients], 1), proto_masks) if self.export else (torch.cat([x[0], mask_coefficients], 1), (x[1], mask_coefficients, proto_masks)) @@ -192,6 +192,7 @@ def __init__(self, num_classes=80, keypoint_shape=(17, 3), in_channels=[]) -> No self.branch_pose_detect = nn.ModuleList( nn.Sequential( ConvBNAct(ch, c4, kernel_size=3, stride=1), + ConvBNAct(c4, c4, kernel_size=3, stride=1), nn.Conv2d(c4, self.num_keypoints, 1) ) for ch in in_channels @@ -200,8 +201,18 @@ def __init__(self, num_classes=80, keypoint_shape=(17, 3), in_channels=[]) -> No def decode_keypoints(self, batch_size, keypoints): ndim = self.keypoint_shape[1] if self.export: - y = keypoints.view(batch_size, *self.keypoint_shape, -1) - a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides + if self.format in { + "tflite", + "edgetpu", + }: + y = keypoints.view(batch_size, *self.keypoint_shape, -1) + grid_h, grid_w = self.shape[2], self.shape[3] + grid_size = torch.tensor([grid_w, grid_h], device=y.device).reshape(1, 2, 1) + norm = self.strides / (self.stride[0] * grid_size) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * norm + else: + y = keypoints.view(batch_size, *self.keypoint_shape, -1) + a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides if ndim == 3: a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2) return a.view(batch_size, self.num_keypoints, -1) @@ -217,7 +228,7 @@ def forward(self, x): batch_size = x[0].shape[0] # batch_size keypoint = torch.cat([self.branch_pose_detect[i](x[i]).view(batch_size, self.num_keypoints, -1) for i in range(self.num_det_layers)], -1) # (batch_size, 17*3, h*w) - x = self.detection(self, x) + x = Detection.forward(self, x) if self.training: return x, keypoint pred_keypoints = self.decode_keypoints(batch_size, keypoint) @@ -233,7 +244,7 @@ def __init__(self, num_classes=80, embed_dim=512, with_bn=False, in_channels=[]) self.with_bn = with_bn self.in_channels = in_channels c3 = max(in_channels[0], min(self.num_classes, 100)) - self.branch3 = nn.ModuleList(nn.Sequential(ConvBNAct(x, c3, 3), nn.Conv2d(c3, embed_dim, 1)) for x in in_channels) + self.branch3 = nn.ModuleList(nn.Sequential(ConvBNAct(ch, c3, stride=1, kernel_size=3), ConvBNAct(c3, c3, kernel_size=3, stride=1), nn.Conv2d(c3, embed_dim, 1)) for ch in in_channels) self.branch4 = nn.ModuleList(BNContrastiveHead(embed_dim) if with_bn else ContrastiveHead() for _ in in_channels) def forward(self, x, text): @@ -260,13 +271,18 @@ def forward(self, x, text): grid_w = shape[3] grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1) norm = self.strides / (self.stride[0] * grid_size) - dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0) * norm[:, :2], xywh=True, dim=1) + dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0) * norm[:, :2]) else: - dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides + dist_box = self.decode_bboxes(self.distributed_focal_loss(box), self.anchors.unsqueeze(0)) * self.strides y = torch.cat((dist_box, cls.sigmoid()), 1) return y if self.export else (y, x) + + def bias_init(self): + detection_module = self + for branch_a, branch_b, stride in zip(detection_module.branch_det, detection_module.branch_cls, detection_module.stride): + branch_a[-1].bias.data[:] = 1.0 def get_module_info(self): return f"WorldDetection", f"[{self.num_classes}, {self.embed_dim}, {self.with_bn}, {self.in_channels}]" @@ -276,11 +292,11 @@ class OBBDetection(Detection): def __init__(self, num_classes=80, num_extra_params=1, in_channels=[]) -> None: super().__init__(num_classes, in_channels) self.num_extra = num_extra_params - self.detect = Detection.forward c4 = max(in_channels[0] // 4, self.num_extra) self.oriented_branch = nn.ModuleList( nn.Sequential( - ConvBNAct(ch, c4, 3), + ConvBNAct(ch, c4, kernel_size=3, stride=1), + ConvBNAct(c4, c4, kernel_size=3, stride=1), nn.Conv2d(c4, self.num_extra, 1) ) for ch in in_channels @@ -293,7 +309,7 @@ def forward(self, x): if not self.training: self.angle = angle - x = self.detect(self, x) + x = Detection.forward(self, x) if self.training: return x, angle diff --git a/vajra/nn/modules.py b/vajra/nn/modules.py index c7c887e..7b27c55 100644 --- a/vajra/nn/modules.py +++ b/vajra/nn/modules.py @@ -451,16 +451,44 @@ def get_module_info(self): return "BasicBlock", f"[{self.inplanes}, {self.planes}, {self.stride}, {self.downsample}, {self.groups}, {self.base_width}, {self.dilation}]" class BottleneckV2(nn.Module): - def __init__(self, in_c, out_c, shortcut=True, kernel_size=(3,3), expansion_ratio=0.5, groups=1): # ch_in, ch_out, shortcut, groups, kernels, expand + """Standard bottleneck""" + + def __init__(self, in_c, out_c, kernel_size=(3,3), expansion_ratio=0.5, groups=1, act="silu"): # ch_in, ch_out, shortcut, groups, kernels, expand super().__init__() hidden_c = int(out_c * expansion_ratio) # hidden channels - self.conv1 = ConvBNAct(in_c, hidden_c, kernel_size=kernel_size[0], stride=1) - self.conv2 = ConvBNAct(hidden_c, out_c, kernel_size=kernel_size[1], stride=1, groups=groups, act=None) - self.act = nn.SiLU() - self.add = shortcut and in_c == out_c + assert in_c == out_c, "in channels and out channels must be equal for residual bottleneck" + self.in_c = in_c + self.out_c = out_c + self.kernel_size=kernel_size + self.expansion_ratio=expansion_ratio + self.groups=groups + self.act = act + self.conv1 = ConvBNAct(in_c, hidden_c, kernel_size=kernel_size[0], stride=1, act=act) + self.conv2 = ConvBNAct(hidden_c, out_c, kernel_size=kernel_size[1], stride=1, groups=groups, act=act) def forward(self, x): - return self.act(x + self.conv2(self.conv1(x))) if self.add else self.act(self.conv2(self.conv1(x))) + return x + self.conv2(self.conv1(x)) + + def get_module_info(self): + return "BottleneckV2", f"[{self.in_c}, {self.out_c}, {self.kernel_size}, {self.expansion_ratio}, {self.groups}, {act_table[self.act]}]" + +class MSBottleneck(nn.Module): + def __init__(self, in_c, out_c, kernel_size=3): + super().__init__() + hidden_c = int(out_c * 0.5) + assert in_c == out_c, "in channels and out channels must be equal for MSBottleneck" + self.in_c = in_c + self.out_c = out_c + self.kernel_size = kernel_size + self.conv1 = ConvBNAct(in_c, hidden_c, 1, 1) + self.conv2 = DepthwiseConvBNAct(hidden_c, hidden_c, 1, kernel_size=kernel_size) + self.conv3 = ConvBNAct(hidden_c, out_c, 1, 1) + + def forward(self, x): + return x + self.conv3(self.conv2(self.conv1(x))) + + def get_module_info(self): + return "MSBottleneck", f"[{self.in_c}, {self.out_c}, {self.kernel_size}]" class BottleneckV3(nn.Module): @@ -500,9 +528,9 @@ def forward(self, x): return x + self.conv2(self.dwconv(self.conv1(x))) if self.add else self.conv2(self.dwconv(self.conv1(x))) class RepVGGDW(nn.Module): - def __init__(self, dim) -> None: + def __init__(self, dim, kernel_size=7) -> None: super().__init__() - self.conv = DepthwiseConvBNAct(dim, dim, 1, 7, act=None) + self.conv = DepthwiseConvBNAct(dim, dim, 1, kernel_size=kernel_size, act=None) self.conv1 = DepthwiseConvBNAct(dim, dim, 1, 3, act=None) self.dim = dim self.act = nn.SiLU() @@ -550,7 +578,7 @@ def forward_fuse(self, x): class MerudandaDW(nn.Module): - def __init__(self, in_c, out_c, shortcut=True, expansion_ratio=0.5, use_rep_vgg_dw=False): + def __init__(self, in_c, out_c, shortcut=True, expansion_ratio=0.5, use_rep_vgg_dw=False, kernel_size=3, stride=1): super().__init__() hidden_c = int(out_c * expansion_ratio) self.in_c = in_c @@ -558,11 +586,12 @@ def __init__(self, in_c, out_c, shortcut=True, expansion_ratio=0.5, use_rep_vgg_ self.shortcut=shortcut self.expansion_ratio = expansion_ratio self.use_rep_vgg_dw = use_rep_vgg_dw - self.add = shortcut and in_c == out_c + self.kernel_size = kernel_size + self.add = shortcut and in_c == out_c and stride == 1 self.block = nn.Sequential( DepthwiseConvBNAct(in_c, in_c, kernel_size=3), ConvBNAct(in_c, 2 * hidden_c, 1, 1), - DepthwiseConvBNAct(2 * hidden_c, 2 * hidden_c, 1, 3) if not use_rep_vgg_dw else RepVGGDW(2 * hidden_c), + DepthwiseConvBNAct(2 * hidden_c, 2 * hidden_c, stride=stride, kernel_size=kernel_size) if not use_rep_vgg_dw else RepVGGDW(2 * hidden_c, kernel_size=kernel_size), ConvBNAct(2 * hidden_c, out_c, 1, 1), DepthwiseConvBNAct(out_c, out_c, 1, 3), ) @@ -572,8 +601,243 @@ def forward(self, x): return x + y if self.add else y def get_module_info(self): - return "MerudandaDW", f"[{self.in_c}, {self.out_c}, {self.shortcut}, {self.expansion_ratio}, {self.use_rep_vgg_dw}]" + return "MerudandaDW", f"[{self.in_c}, {self.out_c}, {self.shortcut}, {self.expansion_ratio}, {self.use_rep_vgg_dw}, {self.kernel_size}]" + +class VajraMerudandaMS(nn.Module): + def __init__(self, in_c, out_c, kernel_size=(1, 3, 3), num_blocks=2, expansion_ratio=0.5): + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.kernel_size=kernel_size + self.num_blocks = num_blocks + self.expansion_ratio = expansion_ratio + self.hidden_c = int((out_c * expansion_ratio) * len(kernel_size)) + self.block_c = self.hidden_c // len(kernel_size) + self.conv1 = ConvBNAct(in_c, self.hidden_c, 1, 1) + self.blocks = [] + for i in range(len(kernel_size)): + if i == 0: + self.blocks.append(nn.Identity()) + continue + block = nn.Sequential(*[MerudandaDW(self.block_c, self.block_c, True, 0.5, use_rep_vgg_dw=True, kernel_size=kernel_size[i]) for _ in range(num_blocks)]) + self.blocks.append(block) + self.blocks = nn.ModuleList(self.blocks) + self.conv2 = ConvBNAct(self.hidden_c, out_c, 1, 1) + + def forward(self, x): + conv1 = self.conv1(x) + fms = [] + + for i, block in enumerate(self.blocks): + fm = conv1[:,i*self.block_c:(i+1)*self.block_c,...] + if i > 1: + fm = fm + fms[i-1] + fm = block(fm) + fms.append(fm) + block_out = torch.cat(fms, 1) + out = self.conv2(block_out) + return out + + def get_module_info(self): + return "VajraMerudandaMS", f"[{self.in_c}, {self.out_c}, {self.kernel_size}, {self.num_blocks}, {self.expansion_ratio}]" + +class VajraMerudandaV2(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, expansion_ratio=0.5, inner_block=False): + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.num_blocks = num_blocks + self.expansion_ratio = expansion_ratio + self.hidden_c = int((out_c * expansion_ratio) * 2) + self.block_c = self.hidden_c // 2 + self.conv1 = ConvBNAct(in_c // 2, self.hidden_c, 1, 1) + block = InnerBlock if inner_block else Bottleneck + self.blocks = nn.ModuleList( + block(self.block_c, self.block_c, 2, True, 1, 0.5) for _ in range(num_blocks) + ) if block == InnerBlock else nn.ModuleList( + block(self.block_c, self.block_c, shortcut=True, expansion_ratio=0.5) for _ in range(num_blocks) + ) + self.conv2 = ConvBNAct(in_c // 2 + (num_blocks + 2) * self.block_c, out_c, 1, 1) + self.add = in_c == out_c + + def forward(self, x): + in_1, in_2 = x.chunk(2, 1) + in_2 = in_2 + in_1 + conv1 = self.conv1(in_2) + fm1, fm2 = conv1.chunk(2, 1) + #fm1 = conv1[:, 0:self.block_c, ...] + #fm2 = conv1[:, self.block_c:2*self.block_c, ...] + fm2 = fm2 + fm1 + fms = [in_1, fm1, fm2] + fms.extend(block(fms[-1]) for block in self.blocks) + block_out = torch.cat(fms, 1) + out = self.conv2(block_out) + return out + x if self.add else out + + def get_module_info(self): + return "VajraMerudandaV2", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.expansion_ratio}]" + +class VajraMerudandaV2Bhag1(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, inner_block=False, kernel_size=3): + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.num_blocks = num_blocks + self.block_c = self.in_c // 2 + self.inner_block = inner_block + self.kernel_size = kernel_size + block = InnerBlock if inner_block else Bottleneck + self.conv1 = ConvBNAct(in_c, in_c, 1, 1) + self.blocks = nn.ModuleList( + block(self.block_c, self.block_c, 2, True, 1, 0.5) for _ in range(num_blocks) + ) if block == InnerBlock else nn.ModuleList( + block(self.block_c, self.block_c, shortcut=True, expansion_ratio=0.5) for _ in range(num_blocks) + ) + self.dwconv = nn.Sequential(DepthwiseConvBNAct(self.block_c, self.block_c, 1, kernel_size=kernel_size), ConvBNAct(self.block_c, self.block_c, 1, 1)) + self.conv2 = ConvBNAct(in_c + num_blocks * self.block_c, out_c, 1, 1) + self.add = in_c == out_c + + def forward(self, x): + conv1 = self.conv1(x) + fm1, fm2 = conv1.chunk(2, 1) + fm1 = self.dwconv(fm1) + fm2 = fm2 + fm1 + fms = [fm1, fm2] + fms.extend(block(fms[-1]) for block in self.blocks) + out = self.conv2(torch.cat(fms, 1)) + return out + x if self.add else out + + def get_module_info(self): + return "VajraMerudandaV2Bhag1", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.inner_block}, {self.kernel_size}]" + +class VajraGrivaV2Bhag1(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, inner_block=False): + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.num_blocks = num_blocks + self.inner_block = inner_block + block = InnerBlock if inner_block else Bottleneck + self.block_c = in_c // 4 + self.blocks = nn.ModuleList( + block(self.block_c, self.block_c, 2, True, 1, 0.5) for _ in range(num_blocks) + ) if block == InnerBlock else nn.ModuleList( + block(self.block_c, self.block_c, shortcut=True, expansion_ratio=0.5) for _ in range(num_blocks) + ) + self.conv = ConvBNAct(in_c + num_blocks * self.block_c, out_c, 1, 1) + self.dwconv1 = nn.Sequential(ConvBNAct(in_c // 4, in_c // 4, 1, 1), DepthwiseConvBNAct(in_c // 4, in_c // 4, 1, 3), ConvBNAct(in_c // 4, in_c // 4, 1, 1)) + self.dwconv2 = nn.Sequential(ConvBNAct(in_c // 4, in_c // 4, 1, 1), DepthwiseConvBNAct(in_c // 4, in_c // 4, 1, 3), ConvBNAct(in_c // 4, in_c // 4, 1, 1)) + + def forward(self, inputs): + fm1, fm2, fm3, fm4 = inputs.chunk(4, 1) + fm2 = fm2 + fm1 + fm2 = self.dwconv1(fm2) + fm3 = fm3 + fm2 + fm3 = self.dwconv2(fm3) + fm4 = fm4 + fm3 + fms = [fm1, fm2, fm3, fm4] + fms.extend(block(fms[-1]) for block in self.blocks) + out = self.conv(torch.cat(fms, 1)) + return out + + def get_module_info(self): + return "VajraGrivaV2Bhag1", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.inner_block}]" + +class VajraGrivaV2Bhag2(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, inner_block=False): + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.num_blocks = num_blocks + self.inner_block = inner_block + block = InnerBlock if inner_block else Bottleneck + self.block_c = out_c // 2 + self.hidden_c = in_c - 3 * self.block_c + self.blocks = nn.ModuleList( + block(self.block_c, self.block_c, 2, True, 1, 0.5) for _ in range(num_blocks) + ) if block == InnerBlock else nn.ModuleList( + block(self.block_c, self.block_c, shortcut=True, expansion_ratio=0.5) for _ in range(num_blocks) + ) + self.dwconvs = nn.ModuleList(nn.Sequential(ConvBNAct(self.block_c, self.block_c, 1, 1), DepthwiseConvBNAct(self.block_c, self.block_c, 1, 3), ConvBNAct(self.block_c, self.block_c, 1, 1)) for _ in range(2)) + self.conv = ConvBNAct(in_c + num_blocks * self.block_c, out_c, 1, 1) + + def forward(self, inputs): + fms = list(inputs.split(self.block_c, dim=1)) + for i in range(len(fms)): + if i >= 1: + fms[i] = fms[i] + fms[i-1] + #if i < len(fms) - 1: + #fms[i] = self.dwconvs[i-1](fms[i]) + if i == len(fms) - 3: + fms[i] = self.dwconvs[0](fms[i]) + + if i == len(fms) - 2: + fms[i] = self.dwconvs[1](fms[i]) + + fms.extend(block(fms[-1]) for block in self.blocks) + out = self.conv(torch.cat(fms, 1)) + return out + + def get_module_info(self): + return "VajraGrivaV2Bhag2", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.inner_block}]" + +class AttentionBottleneckV5(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, expansion_ratio=0.5) -> None: + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.hidden_c = int((out_c * expansion_ratio) * 2) + self.block_c = self.hidden_c // 2 + self.conv1 = ConvBNAct(in_c, self.hidden_c, 1, 1) + self.num_blocks = num_blocks + self.expansion_ratio = expansion_ratio + self.attn = nn.ModuleList( + [nn.Identity()] + + [nn.Sequential(*[AttentionBlock(self.block_c, self.block_c, num_heads=self.block_c // 64) for _ in range(num_blocks)])] + ) + self.conv2 = ConvBNAct(self.hidden_c, out_c, 1, 1) + + def forward(self, x): + conv1 = self.conv1(x) + fms = [] + + for i, block in enumerate(self.attn): + fm = conv1[:,i*self.block_c:(i+1)*self.block_c,...] + if i > 1: + fm = fm + fms[i-1] + fm = block(fm) + fms.append(fm) + block_out = torch.cat(fms, 1) + out = self.conv2(block_out) + return out + + def get_module_info(self): + return f"AttentionBottleneckV5", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.expansion_ratio}]" + +class AttentionBottleneckV6(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2) -> None: + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.num_blocks = num_blocks + assert in_c == out_c, "For AttentionBottleneckV6 in channels should be equal to out channels" + self.block_c = out_c // 2 + self.attn = nn.ModuleList(AttentionBlock(self.block_c, self.block_c, num_heads=self.block_c // 64) for _ in range(num_blocks)) + self.dwconv = nn.Sequential(ConvBNAct(self.block_c, self.block_c, 1, 1), DepthwiseConvBNAct(self.block_c, self.block_c, 1, 3), ConvBNAct(self.block_c, self.block_c, 1, 1)) + self.conv = ConvBNAct(in_c + num_blocks * self.block_c, out_c, 1, 1) + + def forward(self, x): + fm1, fm2 = x.chunk(2, 1) + fm1 = self.dwconv(fm1) + fm2 = fm2 + fm1 + fms = [fm1, fm2] + fms.extend(attn_block(fms[-1]) for attn_block in self.attn) + out = self.conv(torch.cat(fms, 1)) + return out + x + def get_module_info(self): + return "AttentionBottleneckV6", f"[{self.in_c}, {self.out_c}, {self.num_blocks}]" + class ADown(nn.Module): def __init__(self, in_c, out_c): super().__init__() @@ -830,36 +1094,33 @@ def get_module_info(self): return f"VajraGrivaBhag4", f"[{self.out_c}, {self.num_blocks}, {self.shortcut}, {self.kernel_size}, {self.expansion_ratio}, {self.bhag1}, {self.use_cbam}]" class VajraMerudandaBhag5(nn.Module): - def __init__(self, in_c, out_c, shortcut=False, num_heads=2) -> None: + def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, expansion_ratio=0.5, inner_block=False, num_bottleneck_blocks=2, use_cbam=False) -> None: super().__init__() + block = InnerBlockV2 if inner_block else BottleneckV2 + hidden_c = int(out_c * expansion_ratio) self.in_c = in_c self.out_c = out_c - self.shortcut = shortcut - self.num_heads = num_heads + self.expansion_ratio = expansion_ratio + self.num_blocks=num_blocks + self.shortcut=shortcut + self.inner_block = inner_block + self.use_cbam = use_cbam + self.kernel_size=kernel_size + self.conv1 = ConvBNAct(in_c, hidden_c, 1, kernel_size) + self.bottleneck_blocks = nn.ModuleList(block(hidden_c, hidden_c, expansion_ratio=0.5) for _ in range(num_blocks)) if block == BottleneckV2 else nn.ModuleList(block(hidden_c, hidden_c, num_bottleneck_blocks, kernel_size=1) for _ in range(num_blocks)) + self.conv2 = ConvBNAct((num_blocks + 1) * hidden_c, out_c, kernel_size=1, stride=1) + #self.cbam = CBAM(out_c) if self.use_cbam else nn.Identity() + #self.add = shortcut and in_c == out_c - hidden_c = int(out_c * 0.5) - block = Bottleneck - self.conv1 = ConvBNAct(in_c, 2 * hidden_c, 1, 3) - self.attn = Attention(dim=hidden_c, num_heads=num_heads) - self.branch_a_bottleneck_blocks = nn.Sequential(*(block(hidden_c, hidden_c, shortcut=shortcut, expansion_ratio=1.0) for _ in range(1))) - self.branch_b_bottleneck_blocks = nn.ModuleList(block(hidden_c, hidden_c, shortcut=shortcut, expansion_ratio=1.0) for _ in range(2)) - self.conv2 = ConvBNAct(in_c + 5 * hidden_c, out_c, 1, 1) - self.cbam = CBAM(out_c) - self.add = shortcut and in_c == out_c - def forward(self, x): - a, b = self.conv1(x).chunk(2, 1) - attn = self.attn(a) - attn = a + attn - branch_a_bottleneck = self.branch_a_bottleneck_blocks(attn) - branch_b_bottleneck = [branch_b_bottleneck_block(b) for branch_b_bottleneck_block in self.branch_b_bottleneck_blocks] - branch_b_bottleneck = torch.cat(branch_b_bottleneck, 1) - conv2 = self.conv2(torch.cat((x, a, b, branch_a_bottleneck, branch_b_bottleneck), 1)) - cbam = self.cbam(conv2) - return x + cbam if self.add else conv2 + cbam + y = [self.conv1(x)] + y.extend(bottleneck(y[-1]) for bottleneck in self.bottleneck_blocks) + y = self.conv2(torch.cat(y, 1)) + #cbam = self.cbam(y) + return y def get_module_info(self): - return f"VajraMerudandaBhag5", f"[{self.in_c}, {self.out_c}, {self.shortcut}, {self.num_heads}]" + return f"VajraMerudandaBhag5", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.shortcut}, {self.kernel_size}, {self.expansion_ratio}, {self.inner_block}, {self.use_cbam}]" class VajraMerudandaBhag6(nn.Module): def __init__(self, in_c, out_c, num_blocks=2, shortcut=False, expansion_ratio = 0.5) -> None: @@ -891,30 +1152,42 @@ def get_module_info(self): return f"VajraMerudandaBhag6", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.shortcut}, {self.expansion_ratio}]" class VajraMerudandaBhag7(nn.Module): - def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, bottleneck_dwcib=False, expansion_ratio=0.5, dw=False) -> None: + def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, expansion_ratio=0.5, inner_block=False, num_bottleneck_blocks=2, use_cbam=False) -> None: super().__init__() - block = MerudandaDW if bottleneck_dwcib else Bottleneck + block = InnerBlock if inner_block else Bottleneck hidden_c = int(out_c * expansion_ratio) self.in_c = in_c self.out_c = out_c self.expansion_ratio = expansion_ratio self.num_blocks=num_blocks self.shortcut=shortcut + self.inner_block = inner_block + self.use_cbam = use_cbam self.kernel_size=kernel_size - self.bottleneck_dwcib = bottleneck_dwcib - self.dwconv = dw - self.conv1 = ConvBNAct(in_c, hidden_c, 1, kernel_size) if not dw else nn.Sequential(DepthwiseConvBNAct(in_c, in_c, 1, 3), ConvBNAct(in_c, hidden_c, 1, 1)) - self.bottleneck_blocks = nn.Sequential(*(block(hidden_c, hidden_c, shortcut=shortcut, expansion_ratio=1.0) for _ in range(num_blocks))) - self.conv2 = ConvBNAct(2 * hidden_c, out_c, kernel_size=1, stride=1) + self.hidden_c = hidden_c + self.conv1 = ConvBNAct(in_c, 2 * hidden_c, 1, kernel_size) + self.bottleneck_blocks = nn.ModuleList(block(hidden_c, hidden_c, shortcut=shortcut, expansion_ratio=0.5) for _ in range(num_blocks)) if block == Bottleneck else nn.ModuleList(block(hidden_c, hidden_c, num_bottleneck_blocks, shortcut=True, kernel_size=1, expansion_ratio=0.5) for _ in range(num_blocks)) + self.conv2 = ConvBNAct((num_blocks + 2) * hidden_c, out_c, kernel_size=1, stride=1) self.add = shortcut and in_c == out_c def forward(self, x): - conv = self.conv1(x) - y = self.conv2(torch.cat((self.bottleneck_blocks(conv), conv), 1)) - return y + x if self.add else y + #a, b = x.chunk(2, 1) + #y = [a, self.conv1(b)] + y = list(self.conv1(x).chunk(2, 1)) + y.extend(bottleneck(y[-1]) for bottleneck in self.bottleneck_blocks) + y = self.conv2(torch.cat(y, 1)) + return y + + def forward_split(self, x): + #a, b = x.split((self.in_c // 2, self.in_c // 2), 1) + #y = [a, self.conv1(b)] + y = list(self.conv1(x).split((self.hidden_c, self.hidden_c), 1)) + y.extend(bottleneck(y[-1]) for bottleneck in self.bottleneck_blocks) + y = self.conv2(torch.cat(y, 1)) + return y def get_module_info(self): - return f"VajraMerudandaBhag7", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.shortcut}, {self.kernel_size}, {self.bottleneck_dwcib}, {self.expansion_ratio}, {self.dwconv}]" + return f"VajraMerudandaBhag7", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.shortcut}, {self.kernel_size}, {self.expansion_ratio}, {self.inner_block}, {self.use_cbam}]" class VajraGrivaBhag1(nn.Module): def __init__(self, out_c, num_blocks=3, kernel_size=1, expansion_ratio=0.5, use_cbam=False, bottleneck_dw=False, use_rep_vgg_dw=False) -> None: @@ -957,6 +1230,25 @@ def forward(self, x): out = self.conv2(torch.cat((a, b), 1)) return x + out if self.add else out +class InnerBlockV2(nn.Module): + def __init__(self, in_c, out_c, num_blocks=1) -> None: + super().__init__() + assert in_c == out_c, "For InnerBlockV2 in channels should be equal to out channels" + hidden_c = int(out_c * 0.5) + self.conv1 = ConvBNAct(in_c, 2 * hidden_c, 1, 1) + self.dwconv = nn.Sequential(DepthwiseConvBNAct(hidden_c, hidden_c, 1, 3), ConvBNAct(hidden_c, hidden_c, 1, 1)) + self.conv2 = ConvBNAct(2 * hidden_c, out_c, kernel_size=1, stride=1) + self.bottleneck_blocks = nn.Sequential(*[BottleneckV2(hidden_c, hidden_c, expansion_ratio=1.0) for _ in range(num_blocks)]) + + def forward(self, x): + conv1 = self.conv1(x) + fm1, fm2 = conv1.chunk(2, 1) + fm1 = self.dwconv(fm1) + fm2 = fm2 + fm1 + fm2 = self.bottleneck_blocks(fm2) + out = self.conv2(torch.cat((fm1, fm2), 1)) + return x + out + class VajraMerudandaBhag2(nn.Module): def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, expansion_ratio=0.5, bhag1=False, use_cbam=False, bottleneck_dw=False) -> None: super().__init__() @@ -1007,7 +1299,7 @@ def get_module_info(self): return f"VajraGrivaBhag2", f"[{self.out_c}, {self.num_blocks}, {self.kernel_size}]" class VajraAttentionBlock(nn.Module): - def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, embed_channels=128, num_heads=1, guide_channels=512) -> None: + def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, embed_channels=128, num_heads=1, guide_channels=512, inner_block=False, use_cbam = False) -> None: super().__init__() block = Bottleneck hidden_c = int(out_c * 0.5) @@ -1019,16 +1311,18 @@ def __init__(self, in_c, out_c, num_blocks=3, shortcut=False, kernel_size=1, emb self.embed_channels = embed_channels self.num_heads = num_heads self.guide_channels = guide_channels - + self.inner_block = inner_block + self.use_cbam = use_cbam + block = InnerBlock if inner_block else Bottleneck self.conv1 = ConvBNAct(in_c, hidden_c, 1, kernel_size) - self.bottleneck_blocks = nn.ModuleList(block(hidden_c, hidden_c, shortcut=shortcut, expansion_ratio=1.0) for _ in range(num_blocks)) - self.conv2 = ConvBNAct((num_blocks + 1) * hidden_c, out_c, kernel_size=1, stride=1) + self.bottleneck_blocks = nn.ModuleList(block(hidden_c, hidden_c, 2, True, 1, 0.5) for _ in range(num_blocks)) if block == InnerBlock else nn.ModuleList(block(hidden_c, hidden_c, shortcut=shortcut, expansion_ratio=0.5) for _ in range(num_blocks)) + self.conv2 = ConvBNAct(in_c + (num_blocks + 2) * hidden_c, out_c, kernel_size=1, stride=1) self.add = shortcut and in_c == out_c - self.cbam = CBAM(out_c) + self.cbam = CBAM(out_c) if self.use_cbam else nn.Identity() self.attn = MaxSigmoidAttentionBlock(hidden_c, hidden_c, num_heads, embed_channels, guide_channels) def forward(self, x, guide): - y = [self.conv1(x)] + y = [x, self.conv1(x)] y.extend(bottleneck(y[-1]) for bottleneck in self.bottleneck_blocks) y.append(self.attn(y[-1], guide)) y = self.conv2(torch.cat(y, 1)) @@ -1036,7 +1330,7 @@ def forward(self, x, guide): return cbam + x if self.add else cbam + y def get_module_info(self): - return f"VajraAttentionBlock", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.shortcut}, {self.kernel_size}, {self.embed_channels}, {self.num_heads}, {self.guide_channels}]" + return f"VajraAttentionBlock", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.shortcut}, {self.kernel_size}, {self.embed_channels}, {self.num_heads}, {self.guide_channels}, {self.inner_block}, {self.use_cbam}]" class VajraV2BottleneckBlock(nn.Module): def __init__(self, in_c, out_c, num_blocks=3, num_bottleneck_blocks=2, shortcut=False, kernel_size=1, bottleneck_kernel_size=3) -> None: @@ -1272,12 +1566,12 @@ def forward(self, x): return self.conv3(torch.cat((self.bottleneck_blocks(self.conv1(x)), self.conv2(x)), 1)) class MaxSigmoidAttentionBlock(nn.Module): - def __init__(self, in_c, out_c, num_heads=1, embed_channels=512, guid_channels=512, scale=False) -> None: + def __init__(self, in_c, out_c, num_heads=1, embed_channels=512, guide_channels=512, scale=False) -> None: super().__init__() self.num_heads = num_heads self.head_channels = out_c // num_heads self.embed_conv = ConvBNAct(in_c, out_c, kernel_size=1, act=None) if in_c != embed_channels else None - self.guide_linear = nn.Linear(guid_channels, embed_channels) + self.guide_linear = nn.Linear(guide_channels, embed_channels) self.bias = nn.Parameter(torch.zeros(num_heads)) self.proj_conv = ConvBNAct(in_c, out_c, kernel_size=3, stride=1, act=None) self.scale = nn.Parameter(torch.ones(1, num_heads, 1, 1)) if scale else 1.0 @@ -1570,6 +1864,38 @@ def forward(self, inputs): def get_module_info(self): return f"SanlayanSPPF", f"[{self.in_c}, {self.out_c}, {self.stride}]" + +class SanlayanSPPFAttention(nn.Module): + def __init__(self, in_c, out_c, stride=2, num_blocks=2) -> None: + super().__init__() + self.in_c = in_c + self.out_c = out_c + self.stride = stride + self.num_blocks = num_blocks + self.branch_a_channels = in_c - self.out_c + self.hidden_c = out_c // 2 + self.out_c = out_c + self.sppf = SPPF(in_c=self.out_c, out_c=self.out_c, kernel_size=5) + self.attn = nn.Sequential(*(AttentionBlock(self.hidden_c, self.hidden_c, num_heads=self.hidden_c // 64) for _ in range(num_blocks))) + self.conv = ConvBNAct(in_c, out_c, 1, 1) + + def forward(self, inputs): + B, C, H, W = inputs[-1].shape + H = (H - 1) // self.stride + 1 + W = (W - 1) // self.stride + 1 + out = [F.interpolate(inp, size=(H, W), mode="nearest") for inp in inputs] + concatenated_in = torch.cat(out, dim=1) + fm1, fm2 = concatenated_in.split((self.branch_a_channels, self.out_c), 1) + fm2 = fm2 + fm1 + sppf = self.sppf(fm2) + fm3, fm4 = sppf.split(self.hidden_c, 1) + fm4 = fm4 + fm3 + attn = self.attn(fm4) + out = self.conv(torch.cat((fm1, fm3, attn), 1)) + return out + + def get_module_info(self): + return f"SanlayanSPPFAttention", f"[{self.in_c}, {self.out_c}, {self.stride}, {self.num_blocks}]" class MBConvEffNet(nn.Module): def __init__(self, in_c, out_c, stride=1, expansion_ratio=4, kernel_size=3) -> None: @@ -2550,6 +2876,51 @@ def forward_split(self, x): def get_module_info(self): return f"AttentionBottleneckV2", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.kernel_size}]" +class AttentionBottleneckV3(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, kernel_size=1) -> None: + super().__init__() + self.in_c = in_c + self.out_c = out_c + hidden_c = int(out_c * 0.5) + self.num_blocks = num_blocks + self.kernel_size=kernel_size + self.conv1 = ConvBNAct(in_c, hidden_c, 1, kernel_size=kernel_size) + self.attn = nn.ModuleList(AttentionBlock(hidden_c, hidden_c, num_heads=hidden_c // 64) for _ in range(num_blocks)) + self.conv2 = ConvBNAct((num_blocks + 1) * hidden_c, out_c, 1, 1) + + def forward(self, x): + y = [self.conv1(x)] + y.extend(attn_block(y[-1]) for attn_block in self.attn) + out = self.conv2(torch.cat(y, 1)) + return out + + def get_module_info(self): + return f"AttentionBottleneckV3", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.kernel_size}]" + +class AttentionBottleneckV4(nn.Module): + def __init__(self, in_c, out_c, num_blocks=2, kernel_size=1) -> None: + super().__init__() + self.in_c = in_c + self.out_c = out_c + hidden_c = int(out_c * 0.5) + assert in_c == out_c, "in channels and out channels must be equal for AttentionBottleneckV4" + self.num_blocks = num_blocks + self.kernel_size=kernel_size + self.conv1 = ConvBNAct(in_c, hidden_c, 1, kernel_size=kernel_size) + self.attn = nn.Sequential(*(AttentionBlock(hidden_c, hidden_c, num_heads=hidden_c // 64) for _ in range(num_blocks))) + self.conv2 = ConvBNAct(2 * hidden_c, out_c, 1, 1) + + def forward(self, x): + #y = [self.conv1(x)] + a = self.conv1(x) + b = self.attn(a) + #y.extend(attn_block(y[-1]) for attn_block in self.attn) + out = self.conv2(torch.cat((a, b), 1)) + return x + out + + def get_module_info(self): + return f"AttentionBottleneckV4", f"[{self.in_c}, {self.out_c}, {self.num_blocks}, {self.kernel_size}]" + class VajraV2Block(nn.Module): def __init__(self, dim) -> None: super().__init__() @@ -3308,16 +3679,10 @@ def __init__(self, in_c, c_mid = 256, out_c = 32) -> None: self.conv1 = ConvBNAct(in_c = in_c, out_c = c_mid, kernel_size=3) self.conv2 = ConvBNAct(in_c = c_mid, out_c = c_mid, kernel_size=3) self.conv3 = ConvBNAct(in_c = c_mid, out_c=out_c) + self.upsample = nn.ConvTranspose2d(c_mid, c_mid, 2, 2, 0, bias=True) def forward(self, x): - fm1 = self.conv1(x) - _, _, H, W = fm1.shape - H = 2 * H - W = 2 * W - up = F.interpolate(fm1, size=(H, W), mode="nearest") - fm2 = self.conv2(up) - out = self.conv3(fm2) - return out + return self.conv3(self.conv2(self.upsample(self.conv1(x)))) class UConv(nn.Module): def __init__(self, in_c, hidden_c = 256, out_c = 256): diff --git a/vajra/nn/vajra.py b/vajra/nn/vajra.py index 0b19d31..e80c492 100644 --- a/vajra/nn/vajra.py +++ b/vajra/nn/vajra.py @@ -116,11 +116,11 @@ def __init__(self, # Backbone self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) self.vajra_block1 = VajraMerudandaBhag3(channels_list[1], channels_list[2], num_repeats[0], 1, True, 0.25, False, inner_block_list[0]) # stride 4 - self.pool1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) + self.conv1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) self.vajra_block2 = VajraMerudandaBhag3(channels_list[2], channels_list[3], num_repeats[1], 1, True, 0.25, False, inner_block_list[1]) # stride 8 - self.pool2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) + self.conv2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) self.vajra_block3 = VajraMerudandaBhag3(channels_list[3], channels_list[4], num_repeats[2], 1, True, inner_block=inner_block_list[2]) # stride 16 - self.pool3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) + self.conv3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) self.vajra_block4 = VajraMerudandaBhag3(channels_list[4], channels_list[4], num_repeats[3], 1, True, inner_block=inner_block_list[3]) # stride 32 self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) @@ -146,14 +146,14 @@ def forward(self, x): stem = self.stem(x) vajra1 = self.vajra_block1(stem) - pool1 = self.pool1(vajra1) - vajra2 = self.vajra_block2(pool1) + conv1 = self.conv1(vajra1) + vajra2 = self.vajra_block2(conv1) - pool2 = self.pool2(vajra2) - vajra3 = self.vajra_block3(pool2) + conv2 = self.conv2(vajra2) + vajra3 = self.vajra_block3(conv2) - pool3 = self.pool3(vajra3) - vajra4 = self.vajra_block4(pool3) + conv3 = self.conv3(vajra3) + vajra4 = self.vajra_block4(conv3) pyramid_pool_backbone = self.pyramid_pool(vajra4) #self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) attn_block = self.attn_block(pyramid_pool_backbone) # Neck @@ -184,59 +184,71 @@ def __init__(self, channels_list = [64, 128, 256, 512, 1024, 256, 256, 256, 256, 256, 256, 256, 256], embed_channels=[256, 128, 256, 512], num_heads = [8, 4, 8, 16], - num_repeats=[3, 6, 6, 3, 3, 3, 3, 3]) -> None: + num_repeats=[2, 2, 2, 2, 2, 2, 2, 2], + inner_block_list = [False, False, True, True, False, False, False, True]) -> None: super().__init__() - self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], [1, 3, 5, -1], -1, [1, 5, 3, -1], -1, [8, 10, -1], -1, [10, 12, -1], -1, [12, 14, 16]] + self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, [5, -1], -1, -1, [3, -1], -1, -1, [11, -1], -1, -1, [9, -1], -1, [13, 16, 19]] # Backbone self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraMerudandaBhag1(channels_list[1], channels_list[1], num_repeats[0], True, 3, False, 0.25) # stride 4 - self.pool1 = MaxPool(kernel_size=2, stride=2) - self.vajra_block2 = VajraMerudandaBhag1(channels_list[1], channels_list[2], num_repeats[1], True, 3, False, 0.25) # stride 8 - self.pool2 = MaxPool(kernel_size=2, stride=2) - self.vajra_block3 = VajraMerudandaBhag1(channels_list[2], channels_list[3], num_repeats[2], True, 3) # stride 16 - self.pool3 = MaxPool(kernel_size=2, stride=2) - self.vajra_block4 = VajraMerudandaBhag1(channels_list[3], channels_list[4], num_repeats[3], True, 3) # stride 32 - self.pyramid_pool = Sanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=2, expansion_ratio=1.0) + self.vajra_block1 = VajraMerudandaBhag3(channels_list[1], channels_list[2], num_repeats[0], 1, True, 0.25, False, inner_block_list[0]) # stride 4 + self.conv1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) + self.vajra_block2 = VajraMerudandaBhag3(channels_list[2], channels_list[3], num_repeats[1], 1, True, 0.25, False, inner_block_list[1]) # stride 8 + self.conv2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) + self.vajra_block3 = VajraMerudandaBhag3(channels_list[3], channels_list[4], num_repeats[2], 1, True, inner_block=inner_block_list[2]) # stride 16 + self.conv3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) + self.vajra_block4 = VajraMerudandaBhag3(channels_list[4], channels_list[4], num_repeats[3], 1, True, inner_block=inner_block_list[3]) # stride 32 + self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) + self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) # Neck - self.fusion4cbam = ChatushtayaSanlayan(in_c=channels_list[1:5], out_c=channels_list[6], expansion_ratio=1.0) - self.vajra_neck1 = VajraAttentionBlock(channels_list[5], channels_list[6], num_repeats[4], False, 1, embed_channels=embed_channels[0], num_heads=num_heads[0]) + self.upsample1 = Upsample(2, "nearest") + self.concat1 = Concatenate(in_c=[channels_list[4], channels_list[4]], dimension=1) + self.vajra_neck1 = VajraAttentionBlock(channels_list[5], channels_list[6], num_repeats[4], False, 1, embed_channels=embed_channels[0], num_heads=num_heads[0], inner_block=inner_block_list[4]) - self.fusion4cbam2 = ChatushtayaSanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[6]], out_c=channels_list[8]) - self.vajra_neck2 = VajraAttentionBlock(channels_list[7], channels_list[8], num_repeats[5], False, 1, embed_channels=embed_channels[1], num_heads=num_heads[1]) + self.upsample2 = Upsample(2, "nearest") + self.concat2 = Concatenate(in_c=[channels_list[6], channels_list[3]], dimension=1) + self.vajra_neck2 = VajraAttentionBlock(channels_list[7], channels_list[8], num_repeats[5], False, 1, embed_channels=embed_channels[1], num_heads=num_heads[1], inner_block=inner_block_list[5]) - self.pyramid_pool_neck1 = Sanlayan(in_c=[channels_list[4], channels_list[6], channels_list[8]], out_c=channels_list[9], stride=2, use_cbam=False) - self.vajra_neck3 = VajraAttentionBlock(channels_list[9], channels_list[10], num_repeats[6], False, 1, embed_channels=embed_channels[2], num_heads=num_heads[2]) + self.neck_conv1 = ConvBNAct(channels_list[8], channels_list[9], 2, 3) + self.concat3 = Concatenate(in_c=[channels_list[6], channels_list[9]], dimension=1) + self.vajra_neck3 = VajraAttentionBlock(channels_list[9], channels_list[10], num_repeats[6], False, 1, embed_channels=embed_channels[2], num_heads=num_heads[2], inner_block=inner_block_list[6]) - self.pyramid_pool_neck2 = Sanlayan(in_c=[channels_list[6], channels_list[8], channels_list[10]], out_c=channels_list[11], stride=2, use_cbam=False) - self.vajra_neck4 = VajraAttentionBlock(channels_list[11], channels_list[12], num_repeats[7], False, 1, embed_channels=embed_channels[3], num_heads=num_heads[3]) + self.neck_conv2 = ConvBNAct(channels_list[10], channels_list[11], 2, 3) + self.concat4 = Concatenate(in_c=[channels_list[11], channels_list[4]], dimension=1) + self.vajra_neck4 = VajraAttentionBlock(channels_list[11], channels_list[12], num_repeats[7], False, 1, embed_channels=embed_channels[3], num_heads=num_heads[3], inner_block=inner_block_list[7]) def forward(self, x): # Backbone stem = self.stem(x) vajra1 = self.vajra_block1(stem) - pool1 = self.pool1(vajra1) - vajra2 = self.vajra_block2(pool1) - - pool2 = self.pool2(vajra2) - vajra3 = self.vajra_block3(pool2) + conv1 = self.conv1(vajra1) + vajra2 = self.vajra_block2(conv1) - pool3 = self.pool3(vajra3) - vajra4 = self.vajra_block4(pool3) - pyramid_pool_backbone = self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) + conv2 = self.conv2(vajra2) + vajra3 = self.vajra_block3(conv2) + conv3 = self.conv3(vajra3) + vajra4 = self.vajra_block4(conv3) + pyramid_pool_backbone = self.pyramid_pool(vajra4) #self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) + attn_block = self.attn_block(pyramid_pool_backbone) # Neck - fusion4 = self.fusion4cbam([vajra1, vajra2, vajra3, pyramid_pool_backbone]) - vajra_neck1 = self.vajra_neck1(fusion4) + #_, _, H3, W3 = vajra3.shape + neck_upsample1 = self.upsample1(attn_block) #F.interpolate(attn_block, size=(H3, W3), mode="nearest") + concat_neck1 = self.concat1([vajra3, neck_upsample1]) + vajra_neck1 = self.vajra_neck1(concat_neck1) - fusion4_2 = self.fusion4cbam2([vajra1, vajra3, vajra2, vajra_neck1]) - vajra_neck2 = self.vajra_neck2(fusion4_2) + #_, _, H2, W2 = vajra2.shape + neck_upsample2 = self.upsample2(vajra_neck1) #F.interpolate(vajra_neck1, size=(H2, W2), mode="nearest") + concat_neck2 = self.concat2([vajra2, neck_upsample2]) + vajra_neck2 = self.vajra_neck2(concat_neck2) - pyramid_pool_neck1 = self.pyramid_pool_neck1([pyramid_pool_backbone, vajra_neck1, vajra_neck2]) - vajra_neck3 = self.vajra_neck3(pyramid_pool_neck1) + neck_conv1 = self.neck_conv1(vajra_neck2) + concat_neck3 = self.concat3([vajra_neck1, neck_conv1]) + vajra_neck3 = self.vajra_neck3(concat_neck3) - pyramid_pool_neck2 = self.pyramid_pool_neck2([vajra_neck1, vajra_neck2, vajra_neck3]) - vajra_neck4 = self.vajra_neck4(pyramid_pool_neck2) + neck_conv2 = self.neck_conv2(vajra_neck3) + concat_neck4 = self.concat4([attn_block, neck_conv2]) + vajra_neck4 = self.vajra_neck4(concat_neck4) outputs = [vajra_neck2, vajra_neck3, vajra_neck4] return outputs @@ -245,33 +257,34 @@ class VajraV1CLSModel(nn.Module): def __init__(self, in_channels=3, channels_list=[64, 128, 256, 512, 1024], - num_repeats=[2, 2, 2, 2]) -> None: + num_repeats=[2, 2, 2, 2], + inner_block_list = [False, False, True, True]) -> None: super().__init__() self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraMerudandaBhag4(channels_list[1], channels_list[2], num_repeats[0], True, 1, 0.25, True) # stride 4 - self.pool1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) - self.vajra_block2 = VajraMerudandaBhag4(channels_list[2], channels_list[3], num_repeats[1], True, 1, 0.25, True) # stride 8 - self.pool2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) - self.vajra_block3 = VajraMerudandaBhag4(channels_list[3], channels_list[4], num_repeats[2], True, 1, bhag1=True) # stride 16 - self.pool3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) - self.vajra_block4 = VajraMerudandaBhag4(channels_list[4], channels_list[4], num_repeats[3], True, 1, bhag1=True) # stride 32 - self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) #Sanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=2, use_cbam=False, expansion_ratio=1.0) + self.vajra_block1 = VajraMerudandaBhag3(channels_list[1], channels_list[2], num_repeats[0], 1, True, 0.25, False, inner_block_list[0]) # stride 4 + self.conv1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) + self.vajra_block2 = VajraMerudandaBhag3(channels_list[2], channels_list[3], num_repeats[1], 1, True, 0.25, False, inner_block_list[1]) # stride 8 + self.conv2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) + self.vajra_block3 = VajraMerudandaBhag3(channels_list[3], channels_list[4], num_repeats[2], 1, True, inner_block=inner_block_list[2]) # stride 16 + self.conv3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) + self.vajra_block4 = VajraMerudandaBhag3(channels_list[4], channels_list[4], num_repeats[3], 1, True, inner_block=inner_block_list[3]) # stride 32 + self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) def forward(self, x): stem = self.stem(x) vajra1 = self.vajra_block1(stem) - pool1 = self.pool1(vajra1) - vajra2 = self.vajra_block2(pool1) + conv1 = self.conv1(vajra1) + vajra2 = self.vajra_block2(conv1) - pool2 = self.pool2(vajra2) - vajra3 = self.vajra_block3(pool2) + conv2 = self.conv2(vajra2) + vajra3 = self.vajra_block3(conv2) - pool3 = self.pool3(vajra3) - vajra4 = self.vajra_block4(pool3) - pyramid_pool_backbone = self.pyramid_pool(vajra4) #self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) + conv3 = self.conv3(vajra3) + vajra4 = self.vajra_block4(conv3) + pyramid_pool_backbone = self.pyramid_pool(vajra4) attn_block = self.attn_block(pyramid_pool_backbone) return attn_block @@ -297,7 +310,7 @@ def build_vajra(in_channels, if version != "v1" and version != "v3": config_dict = {"nano": [0.5, 0.5, 0.25, 1024], "small": [0.5, 0.5, 0.5, 1024], - "medium": [0.5, 0.50, 1.0, 512], + "medium": [0.5, 0.5, 1.0, 512], "large": [1.0, 1.0, 1.0, 512], "xlarge": [1.0, 1.0, 1.5, 512], } @@ -325,6 +338,14 @@ def build_vajra(in_channels, "large": [True, True, True, True, True, True, True, True], "xlarge": [True, True, True, True, True, True, True, True] } + vajra_v2_sanlayan_griva_config = { + "nano": False, + "small": False, + "medium": True, + "large": True, + "xlarge": True, + } + vajra_v2_sanlayan_griva = vajra_v2_sanlayan_griva_config[size] backbone_depth_mul = config_dict[size][0] neck_depth_mul = config_dict[size][1] width_mul = config_dict[size][2] @@ -347,7 +368,7 @@ def build_vajra(in_channels, if version == "v1": model = VajraV1Model(in_channels, channels_list, num_repeats, inner_blocks_list) if model_name.split("-")[1] != "deyo" else VajraV1DEYOModel(in_channels, vajra_deyo_channels_list, num_repeats, inner_blocks_list) elif version == "v2": - model = VajraV2Model(in_channels, channels_list, num_repeats) if model_name.split("-")[1] != "deyo" else VajraV2Model(in_channels, vajra_deyo_channels_list, num_repeats) + model = VajraV2Model(in_channels, channels_list, num_repeats, vajra_v2_sanlayan_griva, inner_blocks_list) if model_name.split("-")[1] != "deyo" else VajraV2Model(in_channels, vajra_deyo_channels_list, num_repeats, inner_blocks_list) elif version == "v3": model = VajraV3Model(in_channels, channels_list, num_repeats, inner_blocks_list) if model_name.split("-")[1] != "deyo" else VajraV3Model(in_channels, vajra_deyo_channels_list, num_repeats, inner_blocks_list) @@ -411,7 +432,12 @@ def build_vajra(in_channels, return vajra, stride, layers, np_model else: - model = VajraV1CLSModel(in_channels=3, channels_list=channels_list, num_repeats=num_repeats) if version == "v1" else VajraV2CLSModel(in_channels=3, channels_list=channels_list, num_repeats=num_repeats) + if version == "v1": + model = VajraV1CLSModel(in_channels=3, channels_list=channels_list, num_repeats=num_repeats) + elif version == "v2": + model = VajraV2CLSModel(in_channels=3, channels_list=channels_list, num_repeats=num_repeats) + elif version == "v3": + model = VajraV3CLSModel(in_channels=3, channels_list=channels_list, num_repeats=num_repeats) np_model = sum(x.numel() for x in model.parameters()) head = Classification(in_c=channels_list[-1], out_c=num_classes) np_head = sum(x.numel() for x in head.parameters()) diff --git a/vajra/nn/vajrav2.py b/vajra/nn/vajrav2.py index 171ff7a..af3012d 100644 --- a/vajra/nn/vajrav2.py +++ b/vajra/nn/vajrav2.py @@ -8,7 +8,7 @@ from pathlib import Path from vajra.checks import check_suffix, check_requirements from vajra.utils.downloads import attempt_download_asset -from vajra.nn.modules import VajraStemBlock, VajraV2StemBlock, VajraStambh, SPPF, Concatenate, VajraStambhV2, VajraV2MerudandaBhag1, ADown, Bottleneck, MerudandaDW, VajraMerudandaBhag1, VajraMerudandaBhag7, VajraMerudandaBhag2, VajraMerudandaBhag3, VajraGrivaBhag1, VajraGrivaBhag2, VajraGrivaBhag3, VajraGrivaBhag4, VajraMerudandaBhag4, VajraMBConvBlock, VajraConvNeXtBlock, Sanlayan, ChatushtayaSanlayan, ConvBNAct, MaxPool, ImagePoolingAttention, VajraWindowAttnBottleneck, VajraV2BottleneckBlock, AttentionBottleneck, AttentionBottleneckV2 +from vajra.nn.modules import VajraStemBlock, VajraV2StemBlock, VajraStambh, SPPF, Concatenate, VajraStambhV2, VajraMerudandaV2Bhag1, VajraV2MerudandaBhag1, VajraGrivaV2Bhag1, VajraGrivaV2Bhag2, ADown, Bottleneck, MerudandaDW, VajraMerudandaBhag1, VajraMerudandaMS, VajraMerudandaBhag7, VajraMerudandaBhag2, VajraMerudandaBhag3, VajraMerudandaBhag5, VajraGrivaBhag1, VajraGrivaBhag2, VajraGrivaBhag3, VajraGrivaBhag4, VajraMerudandaBhag4, VajraMerudandaBhag7, VajraMBConvBlock, VajraConvNeXtBlock, Sanlayan, ChatushtayaSanlayan, ConvBNAct, MaxPool, ImagePoolingAttention, VajraWindowAttnBottleneck, VajraV2BottleneckBlock, AttentionBottleneck, AttentionBottleneckV3, AttentionBottleneckV2, AttentionBottleneckV4, AttentionBottleneckV6, SanlayanSPPFAttention from vajra.nn.head import Detection, OBBDetection, Segementation, Classification, PoseDetection, WorldDetection, Panoptic from vajra.utils import LOGGER, HYPERPARAMS_CFG_DICT, HYPERPARAMS_CFG_KEYS from vajra.utils.torch_utils import model_info, initialize_weights, fuse_conv_and_bn, time_sync, intersect_dicts, scale_img @@ -19,137 +19,56 @@ except ImportError: thop = None -"""class VajraV2Model(nn.Module): - def __init__(self, - in_channels = 3, - channels_list = [64, 128, 256, 512, 1024, 256, 256, 256, 256, 256, 256, 256, 256], - num_repeats=[2, 2, 2, 2, 2, 2, 2, 2], - ) -> None: - super().__init__() - self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], -1, [1, 3, 5, -1], -1, [1, 5, 3, -1], -1, [5 + sum(num_repeats[:4]), 6 + sum(num_repeats[:5]), -1], -1, -1, [6 + sum(num_repeats[:5]), 7 + sum(num_repeats[:6]), -1], -1, -1, [7 + sum(num_repeats[:6]), 9 + sum(num_repeats[:7]), 11 + sum(num_repeats)]] - # Backbone - self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.block1 = nn.Sequential(*[Bottleneck(channels_list[1], channels_list[1], True) for _ in range(num_repeats[0])]) # stride 4 - self.conv3 = ConvBNAct(channels_list[1], channels_list[2], 2, 3) - self.block2 = nn.Sequential(*[Bottleneck(channels_list[2], channels_list[2], True) for _ in range(num_repeats[1])]) # stride 8 - self.conv4 = ConvBNAct(channels_list[2], channels_list[3], 2, 3) - self.block3 = nn.Sequential(*[Bottleneck(channels_list[3], channels_list[3], True) for _ in range(num_repeats[2])]) # stride 16 - self.conv5 = ConvBNAct(channels_list[3], channels_list[4], 2, 3) - self.block4 = nn.Sequential(*[MerudandaDW(channels_list[4], channels_list[4], True, 0.5, True) for _ in range(num_repeats[3])]) # stride 32 - self.pyramid_pool = Sanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=1, use_cbam=False, expansion_ratio=1.0) - self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) - # Neck - self.fusion4cbam = ChatushtayaSanlayan(in_c=channels_list[1:5], out_c=channels_list[6], use_cbam=False, expansion_ratio=1.0) - self.vajra_neck1 = nn.Sequential(*[MerudandaDW(channels_list[6], channels_list[6], True) for _ in range(num_repeats[4])]) - - self.fusion4cbam2 = ChatushtayaSanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[6]], out_c=channels_list[8], use_cbam=False, expansion_ratio=1.0) - self.vajra_neck2 = nn.Sequential(*[Bottleneck(channels_list[8], channels_list[8], True) for _ in range(num_repeats[5])]) - - self.pyramid_pool_neck1 = Sanlayan(in_c=[channels_list[4], channels_list[6], channels_list[8]], out_c=channels_list[9], stride=1, use_cbam=False, expansion_ratio=1.0) - self.neck_conv1 = ConvBNAct(channels_list[9], channels_list[10], 2, 3) - self.vajra_neck3 = nn.Sequential(*[Bottleneck(channels_list[10], channels_list[10], True) for _ in range(num_repeats[6])]) - - self.pyramid_pool_neck2 = Sanlayan(in_c=[channels_list[6], channels_list[8], channels_list[10]], out_c=channels_list[11], stride=1, use_cbam=False, expansion_ratio=1.0) - self.neck_conv2 = ConvBNAct(channels_list[11], channels_list[12], 2, 3) - self.vajra_neck4 = nn.Sequential(*[Bottleneck(channels_list[12], channels_list[12], True) for _ in range(num_repeats[7])]) - - def forward(self, x): - # Backbone - conv1 = self.conv1(x) - conv2 = self.conv2(conv1) - vajra1 = self.block1(conv2) - vajra1 = conv2 + vajra1 - - pool1 = self.conv3(vajra1) - vajra2 = self.block2(pool1) - vajra2 = vajra2 + pool1 - - pool2 = self.conv4(vajra2) - vajra3 = self.block3(pool2) - vajra3 = vajra3 + pool2 - - pool3 = self.conv5(vajra3) - vajra4 = self.block4(pool3) - vajra4 = vajra4 + pool3 - - pyramid_pool_backbone = self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) - attn_block = self.attn_block(pyramid_pool_backbone) - # Neck - fusion4 = self.fusion4cbam([vajra1, vajra2, vajra3, attn_block]) - vajra_neck1 = self.vajra_neck1(fusion4) - vajra_neck1 = vajra_neck1 + fusion4 - #vajra_neck1 = vajra_neck1 + vajra3 - - fusion4_2 = self.fusion4cbam2([vajra1, vajra3, vajra2, vajra_neck1]) - vajra_neck2 = self.vajra_neck2(fusion4_2) - vajra_neck2 = vajra_neck2 + fusion4_2 - #vajra_neck2 = vajra_neck2 + vajra2 - - pyramid_pool_neck1 = self.pyramid_pool_neck1([attn_block, vajra_neck1, vajra_neck2]) - neck_conv1 = self.neck_conv1(pyramid_pool_neck1) - vajra_neck3 = self.vajra_neck3(neck_conv1) - vajra_neck3 = vajra_neck3 + neck_conv1 - #vajra_neck3 = vajra_neck3 + vajra3 - - pyramid_pool_neck2 = self.pyramid_pool_neck2([vajra_neck1, vajra_neck2, vajra_neck3]) - neck_conv2 = self.neck_conv2(pyramid_pool_neck2) - vajra_neck4 = self.vajra_neck4(neck_conv2) - vajra_neck4 = vajra_neck4 + neck_conv2 - #vajra_neck4 = vajra_neck4 + vajra4 - - outputs = [vajra_neck2, vajra_neck3, vajra_neck4] - return outputs -""" - class VajraV2Model(nn.Module): def __init__(self, in_channels = 3, channels_list = [64, 128, 256, 512, 1024, 256, 256, 256, 256, 256, 256, 256, 256], num_repeats=[2, 2, 2, 2, 2, 2, 2, 2], + sanlayan_griva = False, + inner_block_list=[False, False, True, True, False, False, False, True] ) -> None: super().__init__() - self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, [5, -1], -1, [3, -1], -1, -1, [11, -1], -1, -1, [9, -1], -1, [13, 16, 19]] + self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, [5, -1], -1, [3, -1], -1, -1, [11, -1], -1, -1, [9, -1], -1, [13, 16, 19]] # Backbone self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraMerudandaBhag4(channels_list[1], channels_list[2], num_repeats[0], True, 1, 0.25, True) # stride 4 - self.pool1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) - self.vajra_block2 = VajraMerudandaBhag4(channels_list[2], channels_list[3], num_repeats[1], True, 1, 0.25, True) # stride 8 - self.pool2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) - self.vajra_block3 = VajraMerudandaBhag4(channels_list[3], channels_list[4], num_repeats[2], True, 1, inner_block=True) # stride 16 - self.pool3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) - self.vajra_block4 = VajraMerudandaBhag4(channels_list[4], channels_list[4], num_repeats[3], True, 1, inner_block=True) # stride 32 - #self.sanlayan = Sanlayan(in_c=[channels_list[2], channels_list[3], channels_list[4], channels_list[4]], out_c=channels_list[4], stride=1, expansion_ratio=1.0) + self.vajra_block1 = VajraMerudandaV2Bhag1(channels_list[1], channels_list[2], num_blocks=num_repeats[0], inner_block=inner_block_list[0], kernel_size=3) # stride 4 + self.conv1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) + self.vajra_block2 = VajraMerudandaV2Bhag1(channels_list[2], channels_list[3], num_blocks=num_repeats[1], inner_block=inner_block_list[1], kernel_size=5) # stride 8 + self.conv2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) + self.vajra_block3 = VajraMerudandaV2Bhag1(channels_list[3], channels_list[4], num_blocks=num_repeats[2], inner_block=inner_block_list[2], kernel_size=7) # stride 16 + self.conv3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) + self.vajra_block4 = VajraMerudandaV2Bhag1(channels_list[4], channels_list[4], num_blocks=num_repeats[3], inner_block=inner_block_list[3], kernel_size=9) # stride 32 self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) - self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) + self.attn_block = AttentionBottleneckV6(channels_list[4], channels_list[4], 2) # Neck self.concat1 = Concatenate(in_c=[channels_list[4], channels_list[4]], dimension=1) - self.vajra_neck1 = VajraMerudandaBhag4(in_c=2 * channels_list[4], out_c=channels_list[6], num_blocks=num_repeats[4], kernel_size=1, shortcut=True, inner_block=True) + self.vajra_neck1 = VajraGrivaV2Bhag2(in_c=2 * channels_list[4], out_c=channels_list[6], num_blocks=num_repeats[4], inner_block=inner_block_list[4]) if not sanlayan_griva else VajraGrivaV2Bhag1(2 * channels_list[4], channels_list[6], num_repeats[4], inner_block_list[4]) #VajraMerudandaBhag7(in_c=2 * channels_list[4], out_c=channels_list[6], num_blocks=num_repeats[4], kernel_size=1, shortcut=True, inner_block=True) self.concat2 = Concatenate(in_c=[channels_list[6], channels_list[3]], dimension=1) - self.vajra_neck2 = VajraMerudandaBhag4(in_c=channels_list[6] + channels_list[3], out_c=channels_list[8], num_blocks=num_repeats[5], kernel_size=1, shortcut=True, inner_block=True) + self.vajra_neck2 = VajraGrivaV2Bhag2(in_c=channels_list[6] + channels_list[3], out_c=channels_list[8], num_blocks=num_repeats[5], inner_block=inner_block_list[5]) #VajraMerudandaBhag7(in_c=channels_list[6] + channels_list[3], out_c=channels_list[8], num_blocks=num_repeats[5], kernel_size=1, shortcut=True, inner_block=True) self.neck_conv1 = ConvBNAct(channels_list[8], channels_list[9], 2, 3) self.concat3 = Concatenate(in_c=[channels_list[6], channels_list[9]], dimension=1) - self.vajra_neck3 = VajraMerudandaBhag4(in_c=channels_list[6] + channels_list[9], out_c=channels_list[10], num_blocks=num_repeats[6], kernel_size=1, shortcut=True, inner_block=True) + self.vajra_neck3 = VajraGrivaV2Bhag2(in_c=channels_list[6] + channels_list[9], out_c=channels_list[10], num_blocks=num_repeats[6], inner_block=inner_block_list[6]) #VajraMerudandaBhag7(in_c=channels_list[6] + channels_list[9], out_c=channels_list[10], num_blocks=num_repeats[6], kernel_size=1, shortcut=True, inner_block=True) self.neck_conv2 = ConvBNAct(channels_list[10], channels_list[11], 2, 3) self.concat4 = Concatenate(in_c=[channels_list[11], channels_list[4]], dimension=1) - self.vajra_neck4 = VajraMerudandaBhag4(in_c=channels_list[4] + channels_list[11], out_c=channels_list[12], num_blocks=num_repeats[7], kernel_size=1, shortcut=True, inner_block=True) + self.vajra_neck4 = VajraGrivaV2Bhag2(in_c=channels_list[4] + channels_list[11], out_c=channels_list[12], num_blocks=num_repeats[7], inner_block=inner_block_list[7]) if not sanlayan_griva else VajraGrivaV2Bhag1(channels_list[4] + channels_list[11], channels_list[12], num_repeats[7], inner_block_list[7]) #VajraMerudandaBhag7(in_c=channels_list[4] + channels_list[11], out_c=channels_list[12], num_blocks=num_repeats[7], kernel_size=1, shortcut=True, inner_block=True) def forward(self, x): # Backbone stem = self.stem(x) vajra1 = self.vajra_block1(stem) - pool1 = self.pool1(vajra1) - vajra2 = self.vajra_block2(pool1) + conv1 = self.conv1(vajra1) + vajra2 = self.vajra_block2(conv1) - pool2 = self.pool2(vajra2) - vajra3 = self.vajra_block3(pool2) + conv2 = self.conv2(vajra2) + vajra3 = self.vajra_block3(conv2) - pool3 = self.pool3(vajra3) - vajra4 = self.vajra_block4(pool3) - #sanlayan = self.sanlayan([vajra1, vajra2, vajra3, vajra4]) + conv3 = self.conv3(vajra3) + vajra4 = self.vajra_block4(conv3) + pyramidal_pool = self.pyramid_pool(vajra4) attn_block = self.attn_block(pyramidal_pool) # Neck @@ -176,106 +95,37 @@ def forward(self, x): outputs = [vajra_neck2, vajra_neck3, vajra_neck4] return outputs - -"""class VajraV2Model(nn.Module): - def __init__(self, - in_channels = 3, - channels_list = [64, 128, 256, 512, 1024, 256, 256, 256, 256, 256, 256, 256, 256], - num_repeats=[3, 6, 6, 3, 3, 3, 3, 3], - ) -> None: - super().__init__() - self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], [1, 3, 5, -1], -1, [1, 5, 3, -1], -1, [8, 10, -1], -1, [10, 12, -1], -1, [12, 14, 16]] - # Backbone - self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraMerudandaBhag7(channels_list[1], channels_list[1], num_repeats[0], True, 3, False, 0.5, False) # stride 4 - self.pool1 = MaxPool(kernel_size=2, stride=2) - self.vajra_block2 = VajraMerudandaBhag7(channels_list[1], channels_list[2], num_repeats[1], True, 3, False, 0.5, False) # stride 8 - self.pool2 = MaxPool(kernel_size=2, stride=2) - self.vajra_block3 = VajraMerudandaBhag7(channels_list[2], channels_list[3], num_repeats[2], True, 3, True, 0.5, False) # stride 16 - self.pool3 = MaxPool(kernel_size=2, stride=2) - self.vajra_block4 = VajraMerudandaBhag7(channels_list[3], channels_list[4], num_repeats[3], True, 3, True, 0.5, False) # stride 32 - self.pyramid_pool = Sanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=2, use_cbam=False, expansion_ratio=1.0) - - # Neck - self.fusion4cbam = ChatushtayaSanlayan(in_c=channels_list[1:5], out_c=channels_list[6], use_cbam=False, expansion_ratio=0.5) - self.vajra_neck1 = VajraGrivaBhag1(channels_list[6], num_repeats[4], 1, 0.5, False, True) - - self.fusion4cbam2 = ChatushtayaSanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[6]], out_c=channels_list[8], use_cbam=False, expansion_ratio=0.5) - self.vajra_neck2 = VajraGrivaBhag1(channels_list[8], num_repeats[5], 1, 0.5, False) - - self.fusion4cbam3 = ChatushtayaSanlayan(in_c=[channels_list[4], channels_list[6], channels_list[1], channels_list[8]], out_c=channels_list[8], use_cbam=False, expansion_ratio=0.5) - self.vajra_neck3 = VajraGrivaBhag1(channels_list[8], num_repeats[5], 1, 0.5, False) - - self.pyramid_pool_neck1 = Sanlayan(in_c=[channels_list[4], channels_list[6], channels_list[8]], out_c=channels_list[10], stride=2, use_cbam=False, expansion_ratio=0.5) - self.vajra_neck4 = VajraGrivaBhag1(channels_list[10], num_repeats[6], 1, 0.5, False, True) - - self.pyramid_pool_neck2 = Sanlayan(in_c=[channels_list[6], channels_list[8], channels_list[10]], out_c=channels_list[12], stride=2, use_cbam=False, expansion_ratio=0.5) - self.vajra_neck5 = VajraGrivaBhag1(channels_list[12], num_repeats[7], 1, 0.5, False, True) - - def forward(self, x): - # Backbone - stem = self.stem(x) - vajra1 = self.vajra_block1(stem) - - pool1 = self.pool1(vajra1) - vajra2 = self.vajra_block2(pool1) - - pool2 = self.pool2(vajra2) - vajra3 = self.vajra_block3(pool2) - - pool3 = self.pool3(vajra3) - vajra4 = self.vajra_block4(pool3) - pyramid_pool_backbone = self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) - - # Neck - fusion4 = self.fusion4cbam([vajra1, vajra2, vajra3, pyramid_pool_backbone]) - vajra_neck1 = self.vajra_neck1(fusion4) - vajra_neck1 = vajra_neck1 + vajra3 - - fusion4_2 = self.fusion4cbam2([vajra1, vajra3, vajra2, vajra_neck1]) - vajra_neck2 = self.vajra_neck2(fusion4_2) - vajra_neck2 = vajra_neck2 + vajra2 - - pyramid_pool_neck1 = self.pyramid_pool_neck1([pyramid_pool_backbone, vajra_neck1, vajra_neck2]) - vajra_neck4 = self.vajra_neck4(pyramid_pool_neck1) - vajra_neck4 = vajra_neck4 + vajra3 - - pyramid_pool_neck2 = self.pyramid_pool_neck2([vajra_neck1, vajra_neck2, vajra_neck4]) - vajra_neck5 = self.vajra_neck5(pyramid_pool_neck2) - vajra_neck5 = vajra_neck5 + vajra4 - - outputs = [vajra_neck2, vajra_neck4, vajra_neck4] - return outputs""" class VajraV2CLSModel(nn.Module): def __init__(self, in_channels=3, channels_list=[64, 128, 256, 512, 1024], - num_repeats=[3, 6, 6, 3]) -> None: + num_repeats=[2, 2, 2, 2], + inner_block_list=[False, False, True, True, False, False, False, True]) -> None: super().__init__() self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], -1] self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraV2BottleneckBlock(channels_list[1], channels_list[1], num_repeats[0], 1, True, 3, False) # stride 4 - self.pool1 = MaxPool(kernel_size=2, stride=2) - self.vajra_block2 = VajraV2BottleneckBlock(channels_list[1], channels_list[2], num_repeats[1], 1, True, 3, False) # stride 8 - self.pool2 = MaxPool(kernel_size=2, stride=2) - self.vajra_block3 = VajraV2BottleneckBlock(channels_list[2], channels_list[3], num_repeats[2], 1, True, 3, False) # stride 16 - self.pool3 = MaxPool(kernel_size=2, stride=2) - self.vajra_block4 = VajraV2BottleneckBlock(channels_list[3], channels_list[4], num_repeats[3], 1, True, 3, False) # stride 32 - self.pyramid_pool = Sanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=2) + self.vajra_block1 = VajraMerudandaV2Bhag1(channels_list[1], channels_list[2], num_blocks=num_repeats[0], inner_block=inner_block_list[0], kernel_size=3) # stride 4 + self.conv1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) + self.vajra_block2 = VajraMerudandaV2Bhag1(channels_list[2], channels_list[3], num_blocks=num_repeats[1], inner_block=inner_block_list[1], kernel_size=5) # stride 8 + self.conv2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) + self.vajra_block3 = VajraMerudandaV2Bhag1(channels_list[3], channels_list[4], num_blocks=num_repeats[2], inner_block=inner_block_list[2], kernel_size=7) # stride 16 + self.conv3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) + self.vajra_block4 = VajraMerudandaV2Bhag1(channels_list[4], channels_list[4], num_blocks=num_repeats[3], inner_block=inner_block_list[3], kernel_size=9) # stride 32 + self.pyramid_pool_attn = SanlayanSPPFAttention(2 * channels_list[4], channels_list[4], 1, 2) def forward(self, x): stem = self.stem(x) vajra1 = self.vajra_block1(stem) - pool1 = self.pool1(vajra1) - vajra2 = self.vajra_block2(pool1) + conv1 = self.conv1(vajra1) + vajra2 = self.vajra_block2(conv1) - pool2 = self.pool2(vajra2) - vajra3 = self.vajra_block3(pool2) + conv2 = self.conv2(vajra2) + vajra3 = self.vajra_block3(conv2) - pool3 = self.pool3(vajra3) - vajra4 = self.vajra_block4(pool3) - pyramid_pool_backbone = self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) + conv3 = self.conv3(vajra3) + vajra4 = self.vajra_block4(conv3) + pyramid_pool_backbone = self.pyramid_pool_attn([vajra3, vajra4]) return pyramid_pool_backbone \ No newline at end of file diff --git a/vajra/nn/vajrav3.py b/vajra/nn/vajrav3.py index 8c57a6b..5bae544 100644 --- a/vajra/nn/vajrav3.py +++ b/vajra/nn/vajrav3.py @@ -8,7 +8,7 @@ from pathlib import Path from vajra.checks import check_suffix, check_requirements from vajra.utils.downloads import attempt_download_asset -from vajra.nn.modules import VajraStemBlock, VajraV2StemBlock, VajraV3StemBlock, VajraStambh, VajraMerudandaBhag1, VajraMerudandaBhag3, VajraMerudandaBhag6, VajraMerudandaBhag7, VajraGrivaBhag1, VajraGrivaBhag2, VajraMerudandaBhag2, VajraMBConvBlock, VajraConvNeXtBlock, Sanlayan, SPPF, Concatenate, Upsample, SanlayanSPPF, ChatushtayaSanlayan, TritayaSanlayan, AttentionBottleneck, ConvBNAct, DepthwiseConvBNAct, MaxPool, ImagePoolingAttention, VajraWindowAttnBottleneck, VajraV2BottleneckBlock, VajraV3BottleneckBlock, ADown +from vajra.nn.modules import VajraStemBlock, VajraV2StemBlock, VajraV3StemBlock, VajraStambh, VajraMerudandaBhag1, VajraMerudandaBhag3, VajraMerudandaBhag5, VajraMerudandaBhag6, VajraMerudandaBhag7, VajraGrivaBhag1, VajraGrivaBhag2, VajraMerudandaBhag2, VajraMerudandaBhag4, VajraMBConvBlock, VajraConvNeXtBlock, Sanlayan, SPPF, Concatenate, Upsample, SanlayanSPPF, ChatushtayaSanlayan, TritayaSanlayan, AttentionBottleneck, AttentionBottleneckV2, AttentionBottleneckV4, ConvBNAct, DepthwiseConvBNAct, MaxPool, ImagePoolingAttention, VajraWindowAttnBottleneck, VajraV2BottleneckBlock, VajraV3BottleneckBlock, ADown from vajra.nn.head import Detection, OBBDetection, Segementation, Classification, PoseDetection, WorldDetection, Panoptic from vajra.utils import LOGGER, HYPERPARAMS_CFG_DICT, HYPERPARAMS_CFG_KEYS from vajra.utils.torch_utils import model_info, initialize_weights, fuse_conv_and_bn, time_sync, intersect_dicts, scale_img @@ -18,85 +18,6 @@ import thop except ImportError: thop = None - -"""class VajraV3Model(nn.Module): - def __init__(self, - in_channels = 3, - channels_list = [64, 128, 256, 512, 1024, 256, 256, 256, 256, 256, 256, 256, 256], - num_repeats=[2, 2, 2, 2, 2, 2, 2, 2], - ) -> None: - super().__init__() - self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], [1, 3, 5, -1], -1, [1, 5, 3, -1], -1, [8, 10, -1], -1, [10, 12, -1], -1, [12, 14, 16]] - # Backbone - self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraMerudandaBhag1(channels_list[1], channels_list[1], num_repeats[0], True, 3, False) # stride 4 - #self.conv1 = ConvBNAct(channels_list[1], channels_list[2], 2, 3) - self.pool1 = MaxPool(2, 2) - self.vajra_block2 = VajraMerudandaBhag1(channels_list[1], channels_list[2], num_repeats[1], True, 3, False) # stride 8 - #self.conv2 = ConvBNAct(channels_list[2], channels_list[3], 2, 3) - self.pool2 = MaxPool(2, 2) - self.vajra_block3 = VajraMerudandaBhag1(channels_list[2], channels_list[3], num_repeats[2], True, 3, expansion_ratio=0.5, bottleneck_dwcib=True) # stride 16 - #self.conv3 = ConvBNAct(channels_list[3], channels_list[4], 2, 3) - self.pool3 = MaxPool(2, 2) - self.vajra_block4 = VajraMerudandaBhag1(channels_list[3], channels_list[4], num_repeats[3], True, 3, expansion_ratio=0.5, bottleneck_dwcib=True) # stride 32 - self.pyramid_pool = SanlayanSPPF(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=2, expansion_ratio=1.0) - #self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2, 1) - # Neck - self.fusion4cbam = ChatushtayaSanlayan(in_c=channels_list[1:5], out_c=channels_list[6], use_cbam=False, expansion_ratio=0.5) - self.vajra_neck1 = VajraGrivaBhag1(channels_list[6], num_repeats[4], 1, 0.5, False, True) - - self.fusion4cbam2 = ChatushtayaSanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[6]], out_c=channels_list[8], use_cbam=False, expansion_ratio=0.5) - self.vajra_neck2 = VajraGrivaBhag1(channels_list[8], num_repeats[5], 1, 0.5, False) - - self.pyramid_pool_neck1 = Sanlayan(in_c=[channels_list[4], channels_list[6], channels_list[8]], out_c=channels_list[10], stride=2, use_cbam=False, expansion_ratio=0.5) - #self.neck_conv1 = ConvBNAct(channels_list[9], channels_list[10], 2, 3) - self.vajra_neck3 = VajraGrivaBhag1(channels_list[10], num_repeats[6], 1, 0.5, False, True) - - self.pyramid_pool_neck2 = Sanlayan(in_c=[channels_list[6], channels_list[8], channels_list[10]], out_c=channels_list[12], stride=2, use_cbam=False, expansion_ratio=0.5) - #self.neck_conv2 = ConvBNAct(channels_list[11], channels_list[12], 2, 3) - self.vajra_neck4 = VajraGrivaBhag1(channels_list[12], num_repeats[7], 1, 0.5, False, True) - - def forward(self, x): - # Backbone - stem = self.stem(x) - vajra1 = self.vajra_block1(stem) - - pool1 = self.pool1(vajra1) - #pool1 = self.pool1(conv1) - vajra2 = self.vajra_block2(pool1) - - - pool2 = self.pool2(vajra2) - #pool2 = self.pool2(conv2) - vajra3 = self.vajra_block3(pool2) - - pool3 = self.pool3(vajra3) - #pool3 = self.pool3(conv3) - vajra4 = self.vajra_block4(pool3) - pyramid_pool_backbone = self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) - #attn_block = self.attn_block(pyramid_pool_backbone) - # Neck - fusion4 = self.fusion4cbam([vajra1, vajra2, vajra3, pyramid_pool_backbone]) - vajra_neck1 = self.vajra_neck1(fusion4) - vajra_neck1 = vajra_neck1 + vajra3 - - fusion4_2 = self.fusion4cbam2([vajra1, vajra3, vajra2, vajra_neck1]) - vajra_neck2 = self.vajra_neck2(fusion4_2) - vajra_neck2 = vajra_neck2 + vajra2 - - pyramid_pool_neck1 = self.pyramid_pool_neck1([pyramid_pool_backbone, vajra_neck1, vajra_neck2]) - #neck_conv1 = self.neck_conv1(pyramid_pool_neck1) - vajra_neck3 = self.vajra_neck3(pyramid_pool_neck1) - vajra_neck3 = vajra_neck3 + vajra3 - - pyramid_pool_neck2 = self.pyramid_pool_neck2([vajra_neck1, vajra_neck2, vajra_neck3]) - #neck_conv2 = self.neck_conv2(pyramid_pool_neck2) - vajra_neck4 = self.vajra_neck4(pyramid_pool_neck2) - vajra_neck4 = vajra_neck4 + vajra4 - - outputs = [vajra_neck2, vajra_neck3, vajra_neck4] - return outputs -""" class VajraV3Model(nn.Module): def __init__(self, @@ -106,40 +27,41 @@ def __init__(self, inner_block_list=[False, False, True, True, False, False, False, True] ) -> None: super().__init__() - self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], -1, -1, [5, -1], -1, -1, [3, -1], -1, -1, [11, -1], -1, -1, [9, -1], -1, [13, 16, 19]] + self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, [5, -1], -1, -1, [3, -1], -1, -1, [11, -1], -1, -1, [9, -1], -1, [13, 16, 19]] # Backbone - self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraMerudandaBhag3(channels_list[1], channels_list[2], num_repeats[0], 1, True, 0.25, False, inner_block_list[0]) # stride 4 + self.conv1 = ConvBNAct(in_channels, channels_list[0]) #VajraStambh(in_channels, channels_list[0], channels_list[1]) + self.conv2 = ConvBNAct(channels_list[0], channels_list[1], 2, 3) + self.vajra_block1 = VajraMerudandaBhag5(channels_list[1], channels_list[2], num_repeats[0], True, 1, 0.25, inner_block_list[0]) # stride 4 self.pool1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) - self.vajra_block2 = VajraMerudandaBhag3(channels_list[2], channels_list[3], num_repeats[1], 1, True, 0.25, False, inner_block_list[1]) # stride 8 + self.vajra_block2 = VajraMerudandaBhag5(channels_list[2], channels_list[3], num_repeats[1], True, 1, 0.25, inner_block=inner_block_list[1]) # stride 8 self.pool2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) - self.vajra_block3 = VajraMerudandaBhag3(channels_list[3], channels_list[4], num_repeats[2], 1, True, inner_block=inner_block_list[2]) # stride 16 + self.vajra_block3 = VajraMerudandaBhag5(channels_list[3], channels_list[4], num_repeats[2], True, 1, inner_block=inner_block_list[2]) # stride 16 self.pool3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) - self.vajra_block4 = VajraMerudandaBhag3(channels_list[4], channels_list[4], num_repeats[3], 1, True, inner_block=inner_block_list[3]) # stride 32 - self.sanlayan = Sanlayan(in_c=[channels_list[2], channels_list[3], channels_list[4], channels_list[4]], out_c=channels_list[4], stride=1, expansion_ratio=1.0) + self.vajra_block4 = VajraMerudandaBhag5(channels_list[4], channels_list[4], num_repeats[3], True, 1, inner_block=inner_block_list[3]) # stride 32 self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) - #self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) + self.attn_block = AttentionBottleneckV4(channels_list[4], channels_list[4], 2) # Neck self.upsample1 = Upsample(2, "nearest") self.concat1 = Concatenate(in_c=[channels_list[4], channels_list[4]], dimension=1) - self.vajra_neck1 = VajraMerudandaBhag3(in_c=2 * channels_list[4], out_c=channels_list[6], num_blocks=num_repeats[4], kernel_size=1, shortcut=True, inner_block=inner_block_list[4]) + self.vajra_neck1 = VajraMerudandaBhag5(in_c=2 * channels_list[4], out_c=channels_list[6], num_blocks=num_repeats[4], kernel_size=1, shortcut=True, inner_block=inner_block_list[4]) self.upsample2 = Upsample(2, "nearest") self.concat2 = Concatenate(in_c=[channels_list[6], channels_list[3]], dimension=1) - self.vajra_neck2 = VajraMerudandaBhag3(in_c=channels_list[6] + channels_list[3], out_c=channels_list[8], num_blocks=num_repeats[5], kernel_size=1, shortcut=True, inner_block=inner_block_list[5]) + self.vajra_neck2 = VajraMerudandaBhag5(in_c=channels_list[6] + channels_list[3], out_c=channels_list[8], num_blocks=num_repeats[5], kernel_size=1, shortcut=True, inner_block=inner_block_list[5]) self.neck_conv1 = ConvBNAct(channels_list[8], channels_list[9], 2, 3) self.concat3 = Concatenate(in_c=[channels_list[6], channels_list[9]], dimension=1) - self.vajra_neck3 = VajraMerudandaBhag3(in_c=channels_list[6] + channels_list[9], out_c=channels_list[10], num_blocks=num_repeats[6], kernel_size=1, shortcut=True, inner_block=inner_block_list[6]) + self.vajra_neck3 = VajraMerudandaBhag5(in_c=channels_list[6] + channels_list[9], out_c=channels_list[10], num_blocks=num_repeats[6], kernel_size=1, shortcut=True, inner_block=inner_block_list[6]) self.neck_conv2 = ConvBNAct(channels_list[10], channels_list[11], 2, 3) self.concat4 = Concatenate(in_c=[channels_list[11], channels_list[4]], dimension=1) - self.vajra_neck4 = VajraMerudandaBhag3(in_c=channels_list[4] + channels_list[11], out_c=channels_list[12], num_blocks=num_repeats[7], kernel_size=1, shortcut=True, inner_block=inner_block_list[7]) + self.vajra_neck4 = VajraMerudandaBhag5(in_c=channels_list[4] + channels_list[11], out_c=channels_list[12], num_blocks=num_repeats[7], kernel_size=1, shortcut=True, inner_block=inner_block_list[7]) def forward(self, x): # Backbone - stem = self.stem(x) - vajra1 = self.vajra_block1(stem) + conv1 = self.conv1(x) + conv2 = self.conv2(conv1) + vajra1 = self.vajra_block1(conv2) pool1 = self.pool1(vajra1) vajra2 = self.vajra_block2(pool1) @@ -149,11 +71,11 @@ def forward(self, x): pool3 = self.pool3(vajra3) vajra4 = self.vajra_block4(pool3) - sanlayan = self.sanlayan([vajra1, vajra2, vajra3, vajra4]) #self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) - pyramid_pool = self.pyramid_pool(sanlayan) + pyramid_pool_backbone = self.pyramid_pool(vajra4) #self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) + attn_block = self.attn_block(pyramid_pool_backbone) # Neck #_, _, H3, W3 = vajra3.shape - neck_upsample1 = self.upsample1(pyramid_pool) #F.interpolate(attn_block, size=(H3, W3), mode="nearest") + neck_upsample1 = self.upsample1(attn_block) #F.interpolate(attn_block, size=(H3, W3), mode="nearest") concat_neck1 = self.concat1([vajra3, neck_upsample1]) vajra_neck1 = self.vajra_neck1(concat_neck1) vajra_neck1 = vajra_neck1 + vajra3 if self.vajra_neck1.out_c == self.vajra_block3.out_c else vajra_neck1 @@ -169,9 +91,9 @@ def forward(self, x): vajra_neck3 = vajra_neck3 + vajra3 if self.vajra_neck3.out_c == self.vajra_block3.out_c else vajra_neck3 neck_conv2 = self.neck_conv2(vajra_neck3) - concat_neck4 = self.concat4([pyramid_pool, neck_conv2]) + concat_neck4 = self.concat4([attn_block, neck_conv2]) vajra_neck4 = self.vajra_neck4(concat_neck4) - vajra_neck4 = vajra_neck4 + pyramid_pool + vajra_neck4 = vajra_neck4 + attn_block outputs = [vajra_neck2, vajra_neck3, vajra_neck4] return outputs @@ -180,18 +102,22 @@ class VajraV3CLSModel(nn.Module): def __init__(self, in_channels=3, channels_list=[64, 128, 256, 512, 1024], - num_repeats=[3, 6, 6, 3]) -> None: + num_repeats=[2, 2, 2, 2], + inner_block_list=[False, False, True, True, False, False, False, True] + ) -> None: super().__init__() self.from_list = [-1, -1, -1, -1, -1, -1, -1, -1, [1, 3, 5, -1], -1] self.stem = VajraStambh(in_channels, channels_list[0], channels_list[1]) - self.vajra_block1 = VajraV3BottleneckBlock(channels_list[1], channels_list[1], num_repeats[0], 3, True, 3, False) # stride 4 - self.pool1 = MaxPool(kernel_size=2, stride=2) - self.vajra_block2 = VajraV3BottleneckBlock(channels_list[1], channels_list[2], num_repeats[1], 3, True, 3, False) # stride 8 - self.pool2 = MaxPool(kernel_size=2, stride=2) - self.vajra_block3 = VajraV3BottleneckBlock(channels_list[2], channels_list[3], num_repeats[2], 3, True, 3, False) # stride 16 - self.pool3 = MaxPool(kernel_size=2, stride=2) - self.vajra_block4 = VajraV3BottleneckBlock(channels_list[3], channels_list[4], num_repeats[3], 3, True, 3, False) # stride 32 - self.pyramid_pool = Sanlayan(in_c=[channels_list[1], channels_list[2], channels_list[3], channels_list[4]], out_c=channels_list[4], stride=2) + self.vajra_block1 = VajraMerudandaBhag3(channels_list[1], channels_list[2], num_repeats[0], 1, True, 0.25, False, inner_block_list[0]) # stride 4 + self.pool1 = ConvBNAct(channels_list[2], channels_list[2], 2, 3) + self.vajra_block2 = VajraMerudandaBhag3(channels_list[2], channels_list[3], num_repeats[1], 1, True, 0.25, False, inner_block_list[1]) # stride 8 + self.pool2 = ConvBNAct(channels_list[3], channels_list[3], 2, 3) + self.vajra_block3 = VajraMerudandaBhag3(channels_list[3], channels_list[4], num_repeats[2], 1, True, inner_block=inner_block_list[2]) # stride 16 + self.pool3 = ConvBNAct(channels_list[4], channels_list[4], 2, 3) + self.vajra_block4 = VajraMerudandaBhag3(channels_list[4], channels_list[4], num_repeats[3], 1, True, inner_block=inner_block_list[3]) # stride 32 + self.sanlayan = Sanlayan(in_c=[channels_list[2], channels_list[3], channels_list[4], channels_list[4]], out_c=channels_list[4], stride=1, expansion_ratio=1.0) + self.pyramid_pool = SPPF(channels_list[4], channels_list[4]) + self.attn_block = AttentionBottleneck(channels_list[4], channels_list[4], 2) def forward(self, x): stem = self.stem(x) @@ -206,5 +132,6 @@ def forward(self, x): pool3 = self.pool3(vajra3) vajra4 = self.vajra_block4(pool3) pyramid_pool_backbone = self.pyramid_pool([vajra1, vajra2, vajra3, vajra4]) + attn_block = self.attn_block(pyramid_pool_backbone) - return pyramid_pool_backbone \ No newline at end of file + return attn_block \ No newline at end of file