Commit 690b1aa

- update code
1 parent faf5bee commit 690b1aa

File tree

9 files changed: +512 -1 lines changed
Lines changed: 41 additions & 1 deletion
@@ -1,7 +1,47 @@
import torch
import torch.nn as nn

from evo_science.entities.models.abstract_model import AbstractModel


class AbstractTorchModel(nn.Module, AbstractModel):

    def load_weight(self, checkpoint_path: str):
        """
        Load weights from a checkpoint file.

        Args:
            checkpoint_path (str): Path to the checkpoint file.

        Returns:
            self: The model instance with loaded weights.
        """
        # Load the current model state
        model_state = self.state_dict()

        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        checkpoint_state = checkpoint["model"].float().state_dict()

        # Filter and load matching weights
        compatible_weights = {
            k: v for k, v in checkpoint_state.items() if k in model_state and v.shape == model_state[k].shape
        }

        # Update the model with compatible weights
        self.load_state_dict(compatible_weights, strict=False)

        return self

    def get_criterion(self):
        raise NotImplementedError("This method must be implemented in the subclass.")

    def clip_gradients(self, max_norm=10.0):
        """
        Clip gradients of the model's parameters.

        Args:
            max_norm (float): The maximum norm value for gradient clipping. Default is 10.0.
        """
        parameters = self.parameters()
        nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)
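
A minimal usage sketch for context. TinyModel and the checkpoint path are hypothetical, and the sketch assumes AbstractModel imposes no additional abstract methods; the only format assumption about the checkpoint is the one load_weight itself makes, namely that the "model" key holds a full nn.Module.

import torch
import torch.nn as nn

# Hypothetical concrete subclass; AbstractTorchModel itself defines no layers.
class TinyModel(AbstractTorchModel):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(x)

model = TinyModel()
# load_weight expects a checkpoint whose "model" key holds an nn.Module:
# model = model.load_weight("weights/best.pt")  # hypothetical path

loss = model(torch.randn(4, 8)).sum()
loss.backward()
model.clip_gradients(max_norm=10.0)  # clip before the optimizer step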

evo_science/entities/metrics/iou.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
from evo_science.entities.metrics.base_metric import BaseMetric
import torch
import numpy as np


class IOU(BaseMetric):
    name = "Intersection over Union"

    def _calculate_np(self, y_true, y_pred):
        return np.mean(np.diag(y_true @ y_pred.T))

    @staticmethod
    def compute_iou(box1, box2, eps=1e-7):
        # Returns Complete Intersection over Union (CIoU) of box1(1,4) to box2(n,4)

        # Get the coordinates of bounding boxes
        b1_x1, b1_y1, b1_x2, b1_y2 = box1.unbind(-1)
        b2_x1, b2_y1, b2_x2, b2_y2 = box2.unbind(-1)

        # Calculate width and height of boxes
        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1
        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1

        # Calculate intersection area
        inter = torch.clamp((torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)), min=0) * torch.clamp(
            (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)), min=0
        )

        # Calculate union area
        union = w1 * h1 + w2 * h2 - inter + eps

        # Calculate IoU
        iou = inter / union

        # Calculate the convex (smallest enclosing box) width and height
        cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)
        ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)

        # Calculate squared diagonal of the enclosing box
        c2 = cw.pow(2) + ch.pow(2) + eps

        # Calculate squared distance between box centers
        rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)) / 4

        # Calculate aspect ratio consistency term
        v = (4 / (torch.pi**2)) * (torch.atan(w2 / (h2 + eps)) - torch.atan(w1 / (h1 + eps))).pow(2)

        # Calculate alpha for CIoU
        with torch.no_grad():
            alpha = v / (v - iou + (1 + eps))

        # Return CIoU
        return iou - (rho2 / c2 + v * alpha)
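
A quick worked check of compute_iou (a sketch; the xyxy boxes are made up): for two identical boxes the IoU term is 1 and both penalty terms vanish, so CIoU is 1; for disjoint boxes the result goes negative because the center-distance penalty remains even when the overlap is zero.

import torch

box = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
same = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
far = torch.tensor([[4.0, 4.0, 6.0, 6.0]])

print(IOU.compute_iou(box, same))  # ~1.0: perfect overlap, no penalties
print(IOU.compute_iou(box, far))   # < 0: zero overlap plus center-distance penalty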
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
import torch
from torch import nn
import copy
from typing import Callable


class ExponentialMovingAverage:
    """
    Exponential Moving Average (EMA) implementation.

    Maintains a moving average of the model's parameters and buffers.
    Reference:
        - https://github.com/rwightman/pytorch-image-models
        - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999, tau: float = 2000, updates: int = 0):
        self.ema_model = copy.deepcopy(model).eval()
        self.update_count = updates
        self.decay_fn = self._create_decay_function(decay, tau)
        self._freeze_ema_params()

    def _create_decay_function(self, decay: float, tau: float) -> Callable[[int], float]:
        # Ramp the decay from 0 toward `decay` so early updates track the model closely.
        return lambda x: decay * (1 - torch.exp(torch.tensor(-x / tau)).item())

    def _freeze_ema_params(self):
        for param in self.ema_model.parameters():
            param.requires_grad_(False)

    def update(self, model: nn.Module):
        # Unwrap DistributedDataParallel / DataParallel models.
        if hasattr(model, "module"):
            model = model.module

        with torch.no_grad():
            self.update_count += 1
            current_decay = self.decay_fn(self.update_count)

            for ema_param, model_param in zip(self.ema_model.state_dict().values(), model.state_dict().values()):
                if ema_param.dtype.is_floating_point:
                    ema_param.mul_(current_decay).add_(model_param.detach(), alpha=1 - current_decay)
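
A minimal training-loop sketch (the model, data, and optimizer are hypothetical stand-ins): the EMA shadow is updated once per optimizer step, and the averaged copy is the one you would evaluate or export.

import torch
from torch import nn

model = nn.Linear(8, 2)                          # hypothetical model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
ema = ExponentialMovingAverage(model)

for step in range(100):
    x, y = torch.randn(4, 8), torch.randn(4, 2)  # hypothetical batch
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    ema.update(model)                            # update the shadow weights after each step

eval_model = ema.ema_model                       # evaluate/export the averaged copy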
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
import numpy as np


class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.count = 0
        self.total = 0
        self.average = 0

    def update(self, value, n=1):
        # Skip NaN values so a bad batch does not poison the running average.
        if not np.isnan(value):
            self.count += n
            self.total += value * n
            self.average = self.total / self.count

    @property
    def avg(self):
        return self.average
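
A usage sketch (the loss values are made up): each update is weighted by n, typically the batch size, so avg is a per-sample rather than a per-batch average.

meter = AverageMeter()
meter.update(0.9, n=32)   # batch of 32, mean loss 0.9
meter.update(0.7, n=16)   # batch of 16, mean loss 0.7
print(meter.avg)          # (0.9 * 32 + 0.7 * 16) / 48 ≈ 0.833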

evo_science/entities/utils/nms.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import torch
import torchvision
from time import time


class NonMaxSuppression:
    def __init__(self, conf_threshold, iou_threshold, max_wh=7680, max_det=300, max_nms=30000):
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.max_wh = max_wh
        self.max_det = max_det
        self.max_nms = max_nms

    def __call__(self, outputs):
        bs = outputs.shape[0]
        nc = outputs.shape[1] - 4
        # Candidate mask: anchors whose best class score clears the confidence threshold.
        xc = outputs[:, 4 : 4 + nc].amax(1) > self.conf_threshold

        # Time budget: bail out if NMS runs longer than 0.5 s + 0.05 s per image.
        start = time()
        limit = 0.5 + 0.05 * bs

        output = [torch.zeros((0, 6), device=outputs.device)] * bs
        for index, x in enumerate(outputs):
            x = x.transpose(0, -1)[xc[index]]

            if not x.shape[0]:
                continue

            x = self._process_candidates(x, nc)

            if not x.shape[0]:
                continue
            # Keep at most max_nms candidates, sorted by confidence.
            x = x[x[:, 4].argsort(descending=True)[: self.max_nms]]

            x = self._batched_nms(x)

            output[index] = x
            if (time() - start) > limit:
                break

        return output

    def _process_candidates(self, x, nc):
        box, cls = x.split((4, nc), 1)
        box = self._wh2xy(box)
        if nc > 1:
            # Multi-label: one detection row per (box, class) pair above the threshold.
            i, j = (cls > self.conf_threshold).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
        else:
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > self.conf_threshold]
        return x

    def _batched_nms(self, x):
        # Offset boxes by class index so NMS never suppresses across classes.
        c = x[:, 5:6] * self.max_wh
        boxes, scores = x[:, :4] + c, x[:, 4]
        i = torchvision.ops.nms(boxes, scores, self.iou_threshold)
        return x[i[: self.max_det]]

    @staticmethod
    def _wh2xy(x):
        # Convert (center x, center y, width, height) to (x1, y1, x2, y2).
        y = x.clone()
        y[:, 0] = x[:, 0] - x[:, 2] / 2
        y[:, 1] = x[:, 1] - x[:, 3] / 2
        y[:, 2] = x[:, 0] + x[:, 2] / 2
        y[:, 3] = x[:, 1] + x[:, 3] / 2
        return y
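
A shape-contract sketch with random stand-in data (84 = 4 box coordinates + 80 classes is only a COCO-style assumption): the callable takes raw predictions of shape (batch, 4 + nc, anchors) and returns one (n, 6) tensor per image with rows (x1, y1, x2, y2, confidence, class).

import torch

nms = NonMaxSuppression(conf_threshold=0.25, iou_threshold=0.45)
preds = torch.rand(2, 84, 8400)   # (batch, 4 + nc, anchors), random stand-in
detections = nms(preds)
for det in detections:
    print(det.shape)              # (n, 6): x1, y1, x2, y2, conf, cls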
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
from torch import nn
import torch

from evo_science.entities.metrics.iou import IOU


class Assigner(nn.Module):
    def __init__(self, top_k=13, nc=80, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.top_k = top_k
        self.nc = nc
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        batch_size = pd_scores.size(0)
        num_max_boxes = gt_bboxes.size(1)

        if num_max_boxes == 0:
            # No ground-truth boxes: return empty targets matching the normal return signature.
            return (
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(pd_scores[..., 0], dtype=torch.bool),
            )

        # Mask of anchor points that fall inside each ground-truth box.
        num_anchors = anc_points.shape[0]
        shape = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
        mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
        na = pd_bboxes.shape[-2]
        gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask]  # b, max_num_obj, h*w

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
        overlaps[gt_mask] = IOU.compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)

        # Alignment metric: classification score weighted by localization quality.
        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)

        # Select the top-k anchors per ground-truth box by alignment metric.
        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        top_k_indices.masked_fill_(~top_k_mask, 0)

        mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
        ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
        for k in range(self.top_k):
            mask_top_k.scatter_add_(-1, top_k_indices[:, :, k : k + 1], ones)
        mask_top_k.masked_fill_(mask_top_k > 1, 0)  # drop anchors selected more than once
        mask_top_k = mask_top_k.to(align_metric.dtype)
        mask_pos = mask_top_k * mask_in_gts * mask_gt

        # Resolve anchors assigned to multiple ground-truth boxes: keep the highest overlap.
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)

        # Assigned target
        index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_index = target_gt_idx + index * num_max_boxes
        target_labels = gt_labels.long().flatten()[target_index]

        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index]

        # Assigned target scores
        target_labels.clamp_(0)

        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.nc), dtype=torch.int64, device=target_labels.device
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_bboxes, target_scores, fg_mask.bool()
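
A shape-contract sketch with random stand-in data (all sizes hypothetical): pd_scores is (b, anchors, nc), pd_bboxes is (b, anchors, 4) in xyxy, anc_points is (anchors, 2), and the ground-truth tensors are padded to a fixed max_boxes with mask_gt marking the real entries.

import torch

b, anchors, nc, max_boxes = 2, 64, 80, 4
assigner = Assigner(top_k=13, nc=nc)

pd_scores = torch.rand(b, anchors, nc)
pd_bboxes = torch.rand(b, anchors, 4) * 8          # xyxy predictions (random stand-in)
anc_points = torch.rand(anchors, 2) * 8            # anchor-point centers
gt_labels = torch.randint(0, nc, (b, max_boxes, 1))
gt_bboxes = torch.rand(b, max_boxes, 4) * 8
mask_gt = torch.ones(b, max_boxes, 1)              # all padded slots marked real here

target_bboxes, target_scores, fg_mask = assigner(
    pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt
)
print(target_bboxes.shape, target_scores.shape, fg_mask.shape)
# (b, anchors, 4), (b, anchors, nc), (b, anchors)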
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
from torch import nn
import torch
from torch.nn import functional as F
from evo_science.entities.metrics.iou import IOU


class BoxLoss(nn.Module):
    def __init__(self, dfl_ch):
        super().__init__()
        self.dfl_ch = dfl_ch

    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        loss_iou = self._compute_iou_loss(pred_bboxes, target_bboxes, target_scores, target_scores_sum, fg_mask)
        loss_dfl = self._compute_dfl_loss(
            pred_dist, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
        )
        return loss_iou, loss_dfl

    def _compute_iou_loss(self, pred_bboxes, target_bboxes, target_scores, target_scores_sum, fg_mask):
        # Weight each foreground anchor by its assigned target score.
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        iou = IOU.compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        return ((1.0 - iou) * weight).sum() / target_scores_sum

    def _compute_dfl_loss(self, pred_dist, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        target = self._prepare_dfl_target(anchor_points, target_bboxes)
        loss_dfl = self._distribution_focal_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), target[fg_mask])
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        return (loss_dfl * weight).sum() / target_scores_sum

    def _prepare_dfl_target(self, anchor_points, target_bboxes):
        # Express each box as distances (left, top, right, bottom) from its anchor point.
        a, b = target_bboxes.chunk(2, -1)
        target = torch.cat((anchor_points - a, b - anchor_points), -1)
        return target.clamp(0, self.dfl_ch - 0.01)

    @staticmethod
    def _distribution_focal_loss(pred_dist, target):
        # Distribution Focal Loss: cross-entropy on the two integer bins around
        # the continuous target, linearly weighted by proximity.
        tl = target.long()  # left (lower) bin
        tr = tl + 1  # right (upper) bin
        wl = tr - target  # weight of the left bin
        wr = 1 - wl  # weight of the right bin
        left_loss = F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape)
        right_loss = F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape)
        return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True)
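
A standalone sketch of the DFL term with made-up numbers (dfl_ch = 15, so pred_dist has 16 bins per box side): a continuous target of, say, 3.7 is split between integer bins 3 and 4 with weights 0.3 and 0.7, so the loss pulls probability mass toward the two bins bracketing the true distance.

import torch

n, dfl_ch = 3, 15
pred_dist = torch.randn(n * 4, dfl_ch + 1)   # logits over 16 bins per box side
target = torch.rand(n, 4) * (dfl_ch - 0.01)  # continuous left/top/right/bottom distances

loss = BoxLoss._distribution_focal_loss(pred_dist, target)
print(loss.shape)  # (n, 1): one DFL value per foreground box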
