-
Notifications
You must be signed in to change notification settings - Fork 0
DEPLOY - v1.0.1 #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DEPLOY - v1.0.1 #28
Changes from all commits
5e01979
dc3300a
9e86cb0
647aed8
893ce40
080301c
bea947a
20702ca
7b3cedf
36fc45f
9d767fe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| from logging.config import dictConfig | ||
|
|
||
| """ | ||
| Configure logging levels for the entire application | ||
| """ | ||
| def setup_logging(): | ||
| dictConfig({ | ||
| "version": 1, | ||
| "disable_existing_loggers": False, | ||
| "formatters": { | ||
| "default": { | ||
| "format": "%(asctime)s %(levelname)s [%(name)s] %(message)s", | ||
| }, | ||
| }, | ||
| "handlers": { | ||
| "console": { | ||
| "class": "logging.StreamHandler", | ||
| "formatter": "default", | ||
| }, | ||
| }, | ||
| "root": { | ||
| "level": "INFO", | ||
| "handlers": ["console"], | ||
| }, | ||
| "loggers": { | ||
| "uvicorn": {"level": "INFO"}, | ||
| "uvicorn.error": {"level": "INFO"}, | ||
| "uvicorn.access": {"level": "INFO"}, | ||
| "app": {"level": "INFO", "handlers": ["console"], "propagate": False}, | ||
| }, | ||
| }) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,76 +1,152 @@ | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torch.nn.functional as F | ||
| import timm | ||
| from torchvision import transforms | ||
| from PIL import Image | ||
| import json | ||
| from pathlib import Path | ||
| from io import BytesIO | ||
| import logging | ||
|
|
||
| class LightCNN(nn.Module): | ||
| def __init__(self, num_classes): | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
| class EfficientNetBaseline(nn.Module): | ||
| def __init__(self, num_classes, pretrained=True, dropout=0.2): | ||
| super().__init__() | ||
| self.conv1 = nn.Conv2d(3, 8, 3, padding=1) | ||
| self.conv2 = nn.Conv2d(8, 16, 3, padding=1) | ||
| self.pool = nn.MaxPool2d(2, 2) | ||
| self.gap = nn.AdaptiveAvgPool2d((4, 4)) | ||
| self.fc1 = nn.Linear(16 * 4 * 4, 64) | ||
| self.fc2 = nn.Linear(64, num_classes) | ||
| logger.info( | ||
| f"Initializing EfficientNetBaseline with num_classes={num_classes}, pretrained={pretrained}, dropout={dropout}") | ||
|
|
||
| self.backbone = timm.create_model( | ||
| "efficientnet_b3", pretrained=pretrained, num_classes=0, global_pool="avg" | ||
| ) | ||
| feat_dim = self.backbone.num_features | ||
|
|
||
| self.bn = nn.BatchNorm1d(feat_dim) | ||
| self.dp = nn.Dropout(dropout) | ||
| self.fc = nn.Linear(feat_dim, num_classes) | ||
|
|
||
| def forward(self, x): | ||
| x = self.pool(F.relu(self.conv1(x))) | ||
| x = self.pool(F.relu(self.conv2(x))) | ||
| x = self.gap(x) | ||
| x = x.view(x.size(0), -1) | ||
| x = F.relu(self.fc1(x)) | ||
| x = self.fc2(x) | ||
| return x | ||
| feats = self.backbone(x) | ||
| feats = self.bn(feats) | ||
| feats = self.dp(feats) | ||
| logits = self.fc(feats) | ||
| return logits | ||
|
|
||
| class PredictorService: | ||
| def __init__(self, model_path: Path, json_path: Path): | ||
|
|
||
| logger.info(f"Initializing PredictorService with model_path={model_path}, json_path={json_path}") | ||
|
|
||
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
| # check whether a CUDA device is available | ||
| logger.info(f"Using device: {self.device}") | ||
|
|
||
| self.idx2label = self._load_idx2label(json_path) | ||
| self.num_classes = len(self.idx2label) | ||
| logger.info(f"Loaded {self.num_classes} classes") | ||
|
|
||
| self.model = self._load_model(model_path) | ||
| self.transform = transforms.Compose([ | ||
| transforms.Resize((64, 64)), | ||
| transforms.ToTensor(), | ||
| ]) | ||
| logger.info("PredictorService initialized successfully") | ||
|
|
||
| def _load_idx2label(self, json_path: Path) -> dict: | ||
| with open(json_path, "r", encoding="utf-8") as f: | ||
| data = json.load(f) | ||
| idx2label = data.get("idx2label") | ||
| if not idx2label: | ||
| unique_labels = sorted(set(sample["label"] for sample in data["samples"])) | ||
| idx2label = {str(label): f"K-{label:06d}" for label in unique_labels} | ||
| return idx2label | ||
|
|
||
| def _load_model(self, model_path: Path) -> LightCNN: | ||
|
|
||
| # verify the JSON file is read correctly | ||
| logger.info(f"Loading idx2label from {json_path}") | ||
|
|
||
| try: | ||
| with open(json_path, "r", encoding="utf-8") as f: | ||
| data = json.load(f) | ||
| idx2label = data.get("idx2label") | ||
| if not idx2label: | ||
| logger.warning("idx2label not found in JSON, generating from samples") | ||
| unique_labels = sorted(set(sample["label"] for sample in data["samples"])) | ||
| idx2label = {str(label): f"K-{label:06d}" for label in unique_labels} | ||
| return idx2label | ||
|
|
||
| # added exception handling | ||
| except FileNotFoundError: | ||
| logger.error(f"JSON file not found: {json_path}") | ||
| raise | ||
| except json.JSONDecodeError as e: | ||
| logger.error(f"Failed to parse JSON file: {e}") | ||
| raise | ||
| except Exception as e: | ||
| logger.error(f"Unexpected error loading idx2label: {e}", exc_info=True) | ||
| raise | ||
|
|
||
| def _load_model(self, model_path: Path) -> EfficientNetBaseline: | ||
| # log the model path being loaded | ||
| logger.info(f"Loading model from {model_path}") | ||
|
|
||
| import __main__ | ||
| __main__.LightCNN = LightCNN | ||
| model = torch.load(model_path, map_location=self.device, weights_only=False) | ||
| model.eval() | ||
| return model | ||
| __main__.EfficientNetBaseline = EfficientNetBaseline | ||
|
|
||
| try: | ||
| object = torch.load(model_path, map_location=self.device, weights_only=False) | ||
|
|
||
| if isinstance(object, nn.Module) : | ||
| # the loaded object is already a full model | ||
| model = object.to(self.device) | ||
| elif isinstance(object, dict) : | ||
| # the loaded object is a state_dict | ||
| state_dict = object | ||
| for k in ['state_dict', 'model_state_dict', 'model']: | ||
| if k in object and isinstance(object[k], dict): | ||
| state_dict = object[k] | ||
| break | ||
|
|
||
| model = EfficientNetBaseline(self.num_classes).to(self.device) | ||
|
|
||
| missing, unexpected = model.load_state_dict(state_dict, strict=False) | ||
| if missing or unexpected: | ||
| logger.warning(f"[load_state_dict] missing keys: {missing}, unexpected keys: {unexpected}") | ||
| else: | ||
| # unsupported checkpoint type | ||
| error_msg = f"Unsupported checkpoint type: {type(object)}" | ||
| logger.error(error_msg) | ||
| raise TypeError(f"Unsupported checkpoint type: {type(object)}") | ||
|
|
||
| # done | ||
| model.eval() | ||
| logger.info("Model loaded and set to evaluation mode") | ||
| return model | ||
|
|
||
| except FileNotFoundError: | ||
| logger.error(f"Model file not found: {model_path}") | ||
| raise | ||
| except Exception as e: | ||
| logger.error(f"Failed to load model: {e}", exc_info=True) | ||
| raise | ||
|
|
||
|
|
||
| def predict(self, stream_file: BytesIO) -> tuple[str, str, float]: | ||
| image = Image.open(stream_file).convert('RGB') | ||
| input_tensor = self.transform(image).unsqueeze(0).to(self.device) | ||
|
|
||
| with torch.no_grad(): | ||
| output = self.model(input_tensor) | ||
| predicted_idx = torch.argmax(output, dim=1).item() | ||
| confidence = torch.softmax(output, dim=1)[0][predicted_idx].item() | ||
| try : | ||
| image = Image.open(stream_file).convert('RGB') | ||
| input_tensor = self.transform(image).unsqueeze(0).to(self.device) | ||
|
|
||
| with torch.no_grad(): | ||
| output = self.model(input_tensor) | ||
| predicted_idx = torch.argmax(output, dim=1).item() | ||
| confidence = torch.softmax(output, dim=1)[0][predicted_idx].item() | ||
|
|
||
| label = str(predicted_idx) | ||
| pill_name = self.idx2label.get(label, f"Unknown Label: {label}") | ||
| label = str(predicted_idx) | ||
| pill_name = self.idx2label.get(label, f"Unknown Label: {label}") | ||
| logger.info(f"Prediction completed - pill_name: {pill_name}, label: {label}, confidence: {confidence:.4f}") | ||
|
|
||
| return pill_name, label, confidence | ||
| return pill_name, label, confidence | ||
| except Exception as e: | ||
| logger.error(f"Prediction failed: {e}", exc_info=True) | ||
| raise | ||
|
|
||
|
|
||
| HERE = Path(__file__).resolve().parent.parent | ||
| MODEL_PATH = HERE / "models" / "models" / "best_model_0823.pt" | ||
| MODEL_PATH = HERE / "models" / "models" / "best_model_0920.pt" | ||
| JSON_PATH = HERE / "models" / "models" / "matched_all.json" | ||
|
|
||
| predictor_service = PredictorService(MODEL_PATH, JSON_PATH) | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -4,6 +4,9 @@ | |||||
| from app.core.config import settings | ||||||
| from app.schemas.job import ImageJob | ||||||
| from app.worker.tasks import process_image_scan | ||||||
| import logging | ||||||
|
|
||||||
| logger = logging.getLogger(__name__) | ||||||
|
|
||||||
| """ | ||||||
| Preprocessing function for validly formatted messages stored in the Redis Stream | ||||||
|
|
@@ -49,7 +52,7 @@ def __init__(self, redis_client: redis_client): | |||||
| self.redis_client = redis_client | ||||||
|
|
||||||
| async def run(self): | ||||||
| print(f"[worker] start consumer={settings.CONSUMER_NAME} group={settings.GROUP_NAME} stream={settings.STREAM_JOB}") | ||||||
| logging.info(f"[worker] start consumer={settings.CONSUMER_NAME} group={settings.GROUP_NAME} stream={settings.STREAM_JOB}") | ||||||
| reclaim_every_sec = 30 | ||||||
| last_reclaim = 0.0 | ||||||
|
|
||||||
|
|
@@ -82,12 +85,12 @@ async def run(self): | |||||
|
|
||||||
| # final decoded payload data | ||||||
| data = json.loads(payload_str) | ||||||
| print(f"Job received id={msg_id} correlationId={correlation_id} payload={data}") | ||||||
| logging.info(f"Job received id={msg_id} correlationId={correlation_id} payload={data}") | ||||||
|
|
||||||
| job = ImageJob.model_validate(data) | ||||||
| # calls through to XADD | ||||||
| task = asyncio.create_task(process_image_scan(job, redis_client)) | ||||||
| print(f"[worker] {task} λ°ν μ±κ³΅") | ||||||
| logging.info(f"[worker] {task} λ°ν μ±κ³΅") | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| # ack then del only on successful processing | ||||||
| task.add_done_callback(lambda t: asyncio.create_task( | ||||||
|
|
@@ -132,7 +135,7 @@ async def run(self): | |||||
| job = ImageJob.model_validate_json(payload) | ||||||
|
|
||||||
| task = asyncio.create_task(process_image_scan(job, self.redis_client)) | ||||||
| print(f"[worker] {task} λ°ν μ±κ³΅") | ||||||
| logging.info(f"[worker] {task} λ°ν μ±κ³΅") | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| def _on_done(t: asyncio.Task, *, msg_id=msg_id, fields=fields): | ||||||
| async def _ack_or_dlq(): | ||||||
|
|
@@ -155,8 +158,8 @@ async def _ack_or_dlq(): | |||||
| await self.redis_client.xadd(f"{settings.STREAM_JOB}:DLQ", clean) | ||||||
|
|
||||||
| except asyncio.CancelledError: | ||||||
| print("[worker] cancelled; bye") | ||||||
| logging.warning("[worker] cancelled; bye") | ||||||
| break | ||||||
| except Exception as e: | ||||||
| print(f"[worker] error: {e}") | ||||||
| logging.warning(f"[worker] error: {e}") | ||||||
| await asyncio.sleep(1) | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Calling torch.load with weights_only=False is a security risk. This option uses Python's pickle to deserialize objects, so loading a maliciously crafted model file (.pt) can execute arbitrary code. It is much safer to save only the model's state_dict and load the weights with model.load_state_dict(). Since the current code already contains logic that handles a state_dict, we strongly recommend making that path the standard one and avoiding weights_only=False.