Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed app/models/models/best_model_0823.pt
Binary file not shown.
Binary file added app/models/models/best_model_0920.pt
Binary file not shown.
61 changes: 41 additions & 20 deletions app/services/predictor_service.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@

import json
import logging
from io import BytesIO
from pathlib import Path

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

class LightCNN(nn.Module):
def __init__(self, num_classes):
class EfficientNetBaseline(nn.Module):
def __init__(self, num_classes, pretrained=True, dropout=0.2):
super().__init__()
self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
self.conv2 = nn.Conv2d(8, 16, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.gap = nn.AdaptiveAvgPool2d((4, 4))
self.fc1 = nn.Linear(16 * 4 * 4, 64)
self.fc2 = nn.Linear(64, num_classes)
self.backbone = timm.create_model(
"efficientnet_b3", pretrained=pretrained, num_classes=0, global_pool="avg"
)
feat_dim = self.backbone.num_features
self.bn = nn.BatchNorm1d(feat_dim)
self.dp = nn.Dropout(dropout)
self.fc = nn.Linear(feat_dim, num_classes)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.gap(x)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
feats = self.backbone(x)
feats = self.bn(feats)
feats = self.dp(feats)
logits = self.fc(feats)
return logits

class PredictorService:
def __init__(self, model_path: Path, json_path: Path):
Expand All @@ -47,10 +46,32 @@ def _load_idx2label(self, json_path: Path) -> dict:
idx2label = {str(label): f"K-{label:06d}" for label in unique_labels}
return idx2label

def _load_model(self, model_path: Path) -> LightCNN:
def _load_model(self, model_path: Path) -> EfficientNetBaseline:
import __main__
__main__.LightCNN = LightCNN
model = torch.load(model_path, map_location=self.device, weights_only=False)
__main__.EfficientNetBaseline = EfficientNetBaseline

object = torch.load(model_path, map_location=self.device, weights_only=False)

if isinstance(object, nn.Module) :
# 그 자체로 모델일 때
model = object.to(self.device)
elif isinstance(object, dict) :
# 반환 타입이 state_dict
state_dict = object
for k in ['state_dict', 'model_state_dict', 'model']:
if k in object and isinstance(object[k], dict):
state_dict = object[k]
break

model = EfficientNetBaseline(self.num_classes).to(self.device)

missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
print(f"[load_state_dict] missing keys: {missing}, unexpected keys: {unexpected}")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

print 문을 사용하여 로그를 남기는 것보다 logging 모듈을 사용하는 것이 좋습니다. 서비스 환경에서는 logging을 통해 로그 레벨 관리, 포맷팅, 핸들러 설정 등 더 체계적인 로깅이 가능합니다. 파일 상단에 import logging을 추가하고, 이 부분은 logging.warning(...)을 사용하여 경고 메시지를 기록하는 것을 권장합니다.

Suggested change
print(f"[load_state_dict] missing keys: {missing}, unexpected keys: {unexpected}")
import logging
logging.warning(f"[load_state_dict] missing keys: {missing}, unexpected keys: {unexpected}")

else:
# type 일치하지 않음
raise TypeError(f"Unsupported checkpoint type: {type(object)}")

model.eval()
return model

Expand All @@ -70,7 +91,7 @@ def predict(self, stream_file: BytesIO) -> tuple[str, str, float]:


HERE = Path(__file__).resolve().parent.parent
MODEL_PATH = HERE / "models" / "models" / "best_model_0823.pt"
MODEL_PATH = HERE / "models" / "models" / "best_model_0920.pt"
JSON_PATH = HERE / "models" / "models" / "matched_all.json"

predictor_service = PredictorService(MODEL_PATH, JSON_PATH)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ torch==2.8.0
torchvision==0.23.0
Pillow==11.3.0
dotenv
openai
openai
timm

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

재현 가능한 빌드를 위해 의존성 버전을 고정하는 것이 좋습니다. timm 라이브러리의 버전을 명시해주세요 (예: timm==0.9.16). 또한, 파일의 마지막에 개행 문자를 추가하는 것이 일반적인 관례입니다.

timm==0.9.16