-
Notifications
You must be signed in to change notification settings - Fork 35
/
tagging_model.py
35 lines (30 loc) · 1.09 KB
/
tagging_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from ram.models import ram
class TaggingModule(nn.Module):
    """Wrap a RAM (``ram.models.ram``) checkpoint as an image-tagging module.

    The model is loaded once at construction; calling the module on an image
    returns the list of predicted tag strings.
    """

    def __init__(
        self,
        device='cpu',
        pretrained='checkpoints/ram_swin_large_14m.pth',
        image_size=384,
        vit='swin_l',
    ):
        """Build the preprocessing pipeline and load the RAM model.

        Args:
            device: Device the model and input tensors are moved to.
            pretrained: Path to the RAM checkpoint passed to ``ram(...)``.
            image_size: Square input resolution used both for resizing the
                image and for constructing the model.
            vit: Backbone variant name passed to ``ram(...)``.
        """
        super().__init__()
        import gc

        self.device = device
        # Resize to the square resolution the model is built with, convert to
        # a tensor, then normalize with the commonly used ImageNet channel
        # statistics (3 channels, so the input must be RGB).
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # load RAM Model in inference mode on the requested device
        self.ram = ram(
            pretrained=pretrained,
            image_size=image_size,
            vit=vit,
        ).eval().to(device)
        print('==> Tagging Module Loaded.')
        # Presumably reclaims temporaries left over from checkpoint loading.
        gc.collect()

    @torch.no_grad()
    def forward(self, original_image):
        """Return the tags RAM predicts for a single image.

        Args:
            original_image: Image accepted by ``self.transform``; in practice
                a PIL image, since ``Resize`` is followed by ``ToTensor``.
                Images in a non-RGB mode are converted to RGB first.

        Returns:
            list[str]: Predicted tags, split on ``' | '``; empty when the model
            predicts nothing.
        """
        print('==> Tagging...')
        # Normalize uses 3-channel statistics; PIL images in other modes
        # (L, P, RGBA, CMYK, ...) would yield a channel-count mismatch.
        if getattr(original_image, 'mode', 'RGB') != 'RGB':
            original_image = original_image.convert('RGB')
        # Add a batch dimension of 1 and move to the model's device.
        img = self.transform(original_image).unsqueeze(0).to(self.device)
        # Second return value (Chinese tags, per the original variable name)
        # is not used here.
        tags, _ = self.ram.generate_tag(img)
        print('==> Tagging results: {}'.format(tags[0]))
        # assumes tags[0] is a ' | '-joined string; drop empty entries so that
        # an empty prediction gives [] rather than [''].
        parts = (tag.strip() for tag in tags[0].split(' | '))
        return [tag for tag in parts if tag]