diff --git a/requirements.txt b/requirements.txt index 8ad6b4f..2f5d0b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,94 +1,12 @@ -aiofiles==23.2.1 -aiosqlite==0.20.0 -alembic==1.13.1 -annotated-types==0.6.0 -anyio==4.3.0 -certifi==2024.2.2 -charset-normalizer==3.3.2 -click==8.1.7 -coloredlogs==15.0.1 -contourpy==1.2.1 -cycler==0.12.1 -fastapi==0.110.0 -filelock==3.14.0 -flatbuffers==24.3.25 -fonttools==4.51.0 -fsspec==2024.3.1 -gitdb==4.0.11 -GitPython==3.1.41 -greenlet==3.0.3 -h11==0.14.0 -humanfriendly==10.0 -idna==3.6 -Jinja2==3.1.4 -kiwisolver==1.4.5 -Mako==1.3.2 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -matplotlib==3.8.4 -mdurl==0.1.2 -mkdocs-git-committers-plugin-2==2.2.3 -mkdocs-git-revision-date-localized-plugin==1.2.4 -mkdocs-material==9.5.6 -mkdocs-material-extensions==1.3.1 -mpmath==1.3.0 -mypy==1.9.0 -mypy-extensions==1.0.0 -networkx==3.3 -numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.4.127 -nvidia-nvtx-cu12==12.1.105 -nvidia-tensorrt==99.0.0 -onnx==1.16.0 -onnxruntime-gpu==1.18.0 -onnxsim==0.4.36 -opencv-python==4.9.0.80 -packaging==24.0 -pandas==2.2.2 -pillow==10.3.0 -platformdirs==4.2.0 -protobuf==5.26.1 -psutil==5.9.8 -py-cpuinfo==9.0.0 -pydantic==2.6.4 -pydantic_core==2.16.3 -Pygments==2.18.0 -pyparsing==3.1.2 -python-dateutil==2.9.0.post0 -python-multipart==0.0.9 -pytz==2024.1 -PyYAML==6.0.1 -requests==2.31.0 -rich==13.7.1 -scipy==1.13.0 -seaborn==0.13.2 -six==1.16.0 -smmap==5.0.1 -sniffio==1.3.1 -SQLAlchemy==2.0.28 -starlette==0.36.3 -sympy==1.12 -tensorrt==10.0.1 -tensorrt-cu12==10.0.1 -tensorrt-cu12-bindings==10.0.1 -tensorrt-cu12-libs==10.0.1 -thop==0.1.1.post2209072238 -torch==2.3.0 -torchvision==0.18.0 -tqdm==4.66.4 -triton==2.3.0 -typing_extensions==4.10.0 -tzdata==2024.1 -ultralytics==8.2.14 -urllib3==2.2.1 -uvicorn==0.29.0 +fastapi~=0.103.0 +sqlalchemy~=2.0.25 +starlette~=0.27.0 +pydantic~=1.10.12 +pillow~=11.0.0 +aiofiles~=22.1.0 +opencv-python~=4.9.0.80 +numpy~=1.26.4 +aiohttp~=3.10.10 +sahi~=0.11.18 +ultralytics~=8.3.28 +alembic~=1.8.1 \ No newline at end of file diff --git a/src/detect.py b/src/detect.py index be33235..84deff2 100644 --- a/src/detect.py +++ b/src/detect.py @@ -1,21 +1,29 @@ +import asyncio +import concurrent.futures +import copy import io +import json import logging import os import uuid from datetime import datetime +from typing import List import PIL import aiofiles +import cv2 +import ffmpeg +import google.generativeai as genai import numpy as np +import aiohttp from PIL import Image, ImageDraw -from fastapi import APIRouter, UploadFile, File, Depends, HTTPException +from fastapi import APIRouter, UploadFile, File, Depends, HTTPException, BackgroundTasks +from google.ai.generativelanguage_v1beta.types import content +from sahi import AutoDetectionModel +from sahi.predict import get_sliced_prediction from sqlalchemy import select, update, delete from starlette.concurrency import run_in_threadpool from ultralytics import YOLO -from sahi import AutoDetectionModel -from sahi.predict import get_sliced_prediction -import asyncio -import concurrent.futures from sql_app.db import AsyncDBSession from sql_app.model.Model import Model @@ -23,9 +31,6 @@ from sql_app.model.User import User from user import token_verify -import cv2 -import ffmpeg - detection_router = APIRouter(prefix="/detect", tags=['detect']) limit = 0.2 # seconds @@ -34,19 +39,12 @@ detect_process_pool = concurrent.futures.ProcessPoolExecutor(max_workers=2) # each ingredient's confidence threshold -confidence_filter = { - "mushroom": 0.85, - "okra":0.75, - "heim":0.85, - "beef":0.4, - "chicken":0.4, - "pork":0.4, - "noodle":0.85, - "carrot":0.5, - "common" :0.65 # the ingridient which is not in the filter -} - -def result_processing(img,results): # results:[{"key":[(x1,x2,y1,y2)]}] +confidence_filter = {"mushroom": 0.85, "okra": 0.75, "heim": 0.85, "beef": 0.4, "chicken": 0.4, "pork": 0.4, + "noodle": 0.85, "carrot": 0.5, "common": 0.65 # the ingridient which is not in the filter + } + + +def result_processing(img, results): # results:[{"key":[(x1,x2,y1,y2)]}] path = {} for key, points in results.items(): img_cp = img.copy() @@ -58,23 +56,20 @@ def result_processing(img,results): # results:[{"key":[(x1,x2,y1,y2)]}] path[key] = random_name img_cp.save(random_name) return path - -def image_processing(image,model_path): - model = AutoDetectionModel.from_pretrained( - model_type="yolov8", - device="cuda:0", - model_path=model_path, - ) - results = get_sliced_prediction(image, model, slice_height=500, slice_width=500, overlap_height_ratio=0.2, overlap_width_ratio=0.2) + +def image_processing(image, model_path): + model = AutoDetectionModel.from_pretrained(model_type="yolov8", device="cuda:0", model_path=model_path, ) + results = get_sliced_prediction(image, model, slice_height=500, slice_width=500, overlap_height_ratio=0.2, + overlap_width_ratio=0.2) result = {} - print("%-5s %-20s %-15s %s %s" % ("index", "name", "confidence", "status","threshold")) + print("%-5s %-20s %-15s %s %s" % ("index", "name", "confidence", "status", "threshold")) for idx, det in enumerate(results.object_prediction_list): - show_log = lambda status, th:print("%-5d %-20s %.13f %-6s %s" % (idx, det.category.name, det.score.value, status, th)) - + show_log = lambda status, th: print( + "%-5d %-20s %.13f %-6s %s" % (idx, det.category.name, det.score.value, status, th)) if det.category.name not in confidence_filter: threshold = confidence_filter['common'] @@ -95,7 +90,7 @@ def image_processing(image,model_path): @detection_router.post("/latest/img") -async def detect_image(db: AsyncDBSession, image: UploadFile = File(...)) -> dict[str,str]: +async def detect_image(db: AsyncDBSession, image: UploadFile = File(...)) -> dict[str, str]: stmt = select(Model).order_by(Model.update_date.desc()).limit(1) result = (await db.execute(stmt)).scalars().first() @@ -156,8 +151,7 @@ async def detect_image(version: str, db: AsyncDBSession, image: UploadFile = Fil @detection_router.post("/upload/pt") -async def upload_pt(db: AsyncDBSession, description: str, version: str, - user: User = Depends(token_verify), +async def upload_pt(db: AsyncDBSession, description: str, version: str, user: User = Depends(token_verify), pt: UploadFile = File(...)): if user.level <= 127: raise HTTPException(status_code=401, detail='You are not administrator') @@ -188,12 +182,9 @@ async def upload_pt(db: AsyncDBSession, description: str, version: str, if result: path: str = result.file_path os.remove(path) - stmt = (update(Model).where(Model.version == version).values( - file_path=filename, - description=description, - size=size, - update_date=datetime.now() - )) + stmt = ( + update(Model).where(Model.version == version).values(file_path=filename, description=description, size=size, + update_date=datetime.now())) try: await db.execute(stmt) await db.commit() @@ -202,13 +193,8 @@ async def upload_pt(db: AsyncDBSession, description: str, version: str, raise e else: - model_information = Model( - file_path=filename, - description=description, - size=size, - version=version, - update_date=datetime.now() - ) + model_information = Model(file_path=filename, description=description, size=size, version=version, + update_date=datetime.now()) try: db.add(model_information) @@ -276,8 +262,8 @@ async def delete_version(version: str, db: AsyncDBSession, user: User = Depends( def video_processing(model, filename): - - ffmpeg.input(f"{filename}").filter('fps', fps=30).filter('scale', height='1080', width='-2').output(f"{filename}-convert.mp4").run() + ffmpeg.input(f"{filename}").filter('fps', fps=30).filter('scale', height='1080', width='-2').output( + f"{filename}-convert.mp4").run() cap = cv2.VideoCapture(f"{filename}-convert.mp4") framerate = cap.get(cv2.CAP_PROP_FPS) @@ -355,7 +341,7 @@ async def detect_video(db: AsyncDBSession, video: UploadFile = File(...)): time_old = datetime.now() loop = asyncio.get_event_loop() - result = await loop.run_in_executor(detect_process_pool,video_processing, model, filename) + result = await loop.run_in_executor(detect_process_pool, video_processing, model, filename) logger.info(f"take {datetime.now() - time_old} to detect video") return result @@ -378,3 +364,173 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil result = await run_in_threadpool(video_processing, model, filename) logger.info(f"Take {datetime.now() - time_old} to detect video") return result + + +# The code to implement the + +# GOOGLE API KEY +GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") +BASE_URL = "https://generativelanguage.googleapis.com" + +genai.configure(api_key=GOOGLE_API_KEY) + +chat_session = None +list_of_i = [] +ingredients_set = set() +lock = asyncio.Lock() + + +# If ingredients change, server will need to restart to update the chat session +async def get_chat_session(db:AsyncDBSession): + + async with lock: + global chat_session, list_of_i, ingredients_set + + result = await db.execute(select(Ingredient.name)) + ingredients = [i.name for i in result] + + if list_of_i != ingredients: + chat_session = None # reset chat session + list_of_i = ingredients + ingredients_set = set(ingredients) + if chat_session is None: + await run_in_threadpool(init_session, ingredients) + return copy.deepcopy(chat_session) + + +def init_session(ingredients): + global chat_session + + generation_config = { + "temperature": 1, + "top_p": 0.95, + "top_k": 40, + "max_output_tokens": 8192, + "response_schema": content.Schema(type=content.Type.OBJECT, properties={ + "result": content.Schema(type=content.Type.ARRAY,items=content.Schema(type=content.Type.STRING)), + }), + "response_mime_type": "application/json", + } + + model = genai.GenerativeModel(model_name="gemini-1.5-pro", generation_config=generation_config, + system_instruction="Detect food in image (but not text) and map to the name I provided.", ) + + history = [{ + "role": "user", + "parts": [ + "Detect these shape of food which is on the list " + + str(ingredients) + " on these images down below," + + "than map the detected ingredients to the name i provided as an array with field name \"result\". " + ], + }] + + chat_session = model.start_chat(history=history) + + +# This function only supports .jpg or .png files, +# and if the file is not a .jpg or .png file, it will +async def upload2gemini(path): + # Using Pillow to check if the file is a .jpg or .png file + try: + with PIL.Image.open(path) as img: + if img.format not in ['JPEG', 'PNG']: + raise HTTPException(status_code=415, detail='This endpoint only support jpg or png file') + except PIL.UnidentifiedImageError: + raise HTTPException(status_code=415, detail='We can\'t recognize the file type') + + size_of_file = os.path.getsize(path) + + headers = { + 'X-Goog-Upload-Protocol': 'resumable', + 'X-Goog-Upload-Command': 'start', + 'Content-Type': 'application/json', + 'X-Goog-Upload-Header-Content-Length': str(size_of_file) # header only support str, not int + } + + params = {'key': GOOGLE_API_KEY} + + filename = os.path.basename(path) + + data = "{'file': {'display_name': '" + filename + "'}}" + + # response = requests.post(f'{BASE_URL}/upload/v1beta/files', params=params, headers=headers, data=data) + async with aiohttp.ClientSession() as session: + async with session.post(f'{BASE_URL}/upload/v1beta/files', params=params, headers=headers, data=data) as response: + if response.status != 200: + raise HTTPException(status_code=response.status_code, + detail="Error when upload image to gemini server. details: " + str(await response.json())) + + upload_url = response.headers["x-goog-upload-url"] + + headers = { + 'X-Goog-Upload-Offset': '0', + 'X-Goog-Upload-Command': 'upload, finalize', + 'Content-Type': 'application/x-www-form-urlencoded' + } + + async with aiofiles.open(path, 'rb') as file: + data = await file.read() + + # res = requests.post(f'{upload_url}', headers=headers, data=data).json()['file'] + async with session.post(f'{upload_url}', headers=headers, data=data) as response: + if response.status != 200: + raise HTTPException(status_code=response.status_code, + detail="Error when upload image to gemini server. details: " + str(await response.json())) + res = (await response.json())['file'] + + args = { + 'name': res['name'], + 'display_name': res['displayName'], + 'mime_type': res['mimeType'], + 'sha256_hash': res['sha256Hash'], + 'size_bytes': res['sizeBytes'], + 'state': res['state'], + 'uri': res['uri'], + 'create_time': res['createTime'], + 'expiration_time': res['expirationTime'], + 'update_time': res['updateTime'] + } + + return genai.types.File(args) + + +def detect_files(session, file): + response = session.send_message(file) + output = json.loads(response.text) + result = [] + + # ensure all the ingredients in output are on the set of ingredients + for i in output['result']: + if i in ingredients_set: + result.append(i) + else: + print(f"Warning: {i} is not in the list of ingredients.") + + return result + + + + +@detection_router.post("/gemini") +async def detect_by_gemini(background:BackgroundTasks,session = Depends(get_chat_session),files: List[UploadFile] = File(...)) -> List[str]: + """ + Detect the ingredients in the image by using the gemini API + + :param files: List of images + """ + + upload_coroutine = [] + + for file in files: + random_name = uuid.uuid4().hex + async with aiofiles.open(f"./img/{random_name}", 'wb') as output_file: + await output_file.write(await file.read()) + upload_coroutine.append(upload2gemini(f"./img/{random_name}")) + + upload_files = await asyncio.gather(*upload_coroutine) + + response = await run_in_threadpool(detect_files, session, upload_files) + for i in upload_files: + background.add_task(run_in_threadpool,i.delete) + + return response