From f8fcb590203da504ddee01c58252fe449d24ba20 Mon Sep 17 00:00:00 2001 From: bloodnighttw Date: Sun, 10 Nov 2024 01:11:37 +0800 Subject: [PATCH 1/6] add gemini endpoint. --- src/detect.py | 246 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 227 insertions(+), 19 deletions(-) diff --git a/src/detect.py b/src/detect.py index be33235..0c39069 100644 --- a/src/detect.py +++ b/src/detect.py @@ -1,12 +1,16 @@ +import copy import io +import json import logging import os import uuid from datetime import datetime +from typing import List import PIL import aiofiles import numpy as np +import requests from PIL import Image, ImageDraw from fastapi import APIRouter, UploadFile, File, Depends, HTTPException from sqlalchemy import select, update, delete @@ -26,6 +30,9 @@ import cv2 import ffmpeg +import google.generativeai as genai +from google.ai.generativelanguage_v1beta.types import content + detection_router = APIRouter(prefix="/detect", tags=['detect']) limit = 0.2 # seconds @@ -36,17 +43,18 @@ # each ingredient's confidence threshold confidence_filter = { "mushroom": 0.85, - "okra":0.75, - "heim":0.85, - "beef":0.4, - "chicken":0.4, - "pork":0.4, - "noodle":0.85, - "carrot":0.5, - "common" :0.65 # the ingridient which is not in the filter + "okra": 0.75, + "heim": 0.85, + "beef": 0.4, + "chicken": 0.4, + "pork": 0.4, + "noodle": 0.85, + "carrot": 0.5, + "common": 0.65 # the ingridient which is not in the filter } -def result_processing(img,results): # results:[{"key":[(x1,x2,y1,y2)]}] + +def result_processing(img, results): # results:[{"key":[(x1,x2,y1,y2)]}] path = {} for key, points in results.items(): img_cp = img.copy() @@ -58,23 +66,24 @@ def result_processing(img,results): # results:[{"key":[(x1,x2,y1,y2)]}] path[key] = random_name img_cp.save(random_name) return path - -def image_processing(image,model_path): + +def image_processing(image, model_path): model = AutoDetectionModel.from_pretrained( model_type="yolov8", device="cuda:0", model_path=model_path, ) - results = get_sliced_prediction(image, model, slice_height=500, slice_width=500, overlap_height_ratio=0.2, overlap_width_ratio=0.2) + results = get_sliced_prediction(image, model, slice_height=500, slice_width=500, overlap_height_ratio=0.2, + overlap_width_ratio=0.2) result = {} - print("%-5s %-20s %-15s %s %s" % ("index", "name", "confidence", "status","threshold")) + print("%-5s %-20s %-15s %s %s" % ("index", "name", "confidence", "status", "threshold")) for idx, det in enumerate(results.object_prediction_list): - show_log = lambda status, th:print("%-5d %-20s %.13f %-6s %s" % (idx, det.category.name, det.score.value, status, th)) - + show_log = lambda status, th: print( + "%-5d %-20s %.13f %-6s %s" % (idx, det.category.name, det.score.value, status, th)) if det.category.name not in confidence_filter: threshold = confidence_filter['common'] @@ -95,7 +104,7 @@ def image_processing(image,model_path): @detection_router.post("/latest/img") -async def detect_image(db: AsyncDBSession, image: UploadFile = File(...)) -> dict[str,str]: +async def detect_image(db: AsyncDBSession, image: UploadFile = File(...)) -> dict[str, str]: stmt = select(Model).order_by(Model.update_date.desc()).limit(1) result = (await db.execute(stmt)).scalars().first() @@ -276,8 +285,8 @@ async def delete_version(version: str, db: AsyncDBSession, user: User = Depends( def video_processing(model, filename): - - ffmpeg.input(f"{filename}").filter('fps', fps=30).filter('scale', height='1080', width='-2').output(f"{filename}-convert.mp4").run() + ffmpeg.input(f"{filename}").filter('fps', fps=30).filter('scale', height='1080', width='-2').output( + f"{filename}-convert.mp4").run() cap = cv2.VideoCapture(f"{filename}-convert.mp4") framerate = cap.get(cv2.CAP_PROP_FPS) @@ -355,7 +364,7 @@ async def detect_video(db: AsyncDBSession, video: UploadFile = File(...)): time_old = datetime.now() loop = asyncio.get_event_loop() - result = await loop.run_in_executor(detect_process_pool,video_processing, model, filename) + result = await loop.run_in_executor(detect_process_pool, video_processing, model, filename) logger.info(f"take {datetime.now() - time_old} to detect video") return result @@ -378,3 +387,202 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil result = await run_in_threadpool(video_processing, model, filename) logger.info(f"Take {datetime.now() - time_old} to detect video") return result + + +# The code to implement the + +# GOOGLE API KEY +# GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") +GOOGLE_API_KEY = "AIzaSyAxOCHULsHUHe7W7b1HXP_NGmq2KT_hozc" +BASE_URL = "https://generativelanguage.googleapis.com" + +genai.configure(api_key=GOOGLE_API_KEY) + +chat_session = None +lock = asyncio.Lock() + + +# If ingredients change, server will need to restart to update the chat session +async def get_chat_session(ingredients): + async with lock: + global chat_session + if chat_session is None: + await run_in_threadpool(init_session, ingredients) + return copy.deepcopy(chat_session) + + +def init_session(ingredients): + global chat_session + + generation_config = { + "temperature": 1, + "top_p": 0.95, + "top_k": 40, + "max_output_tokens": 8192, + "response_schema": content.Schema( + type=content.Type.OBJECT, + properties={ + "result": content.Schema( + type=content.Type.ARRAY, + items=content.Schema( + type=content.Type.STRING, + ), + ), + }, + ), + "response_mime_type": "application/json", + } + + model = genai.GenerativeModel( + model_name="gemini-1.5-pro", + generation_config=generation_config, + system_instruction="Detect food in image (but not text) and map to the name I provided.", + ) + + history = [ + { + "role": "user", + "parts": [ + "Detect these shape of food which is on the list " + str( + ingredients) + "on these images down below, ,than map the detected ingredients to the name i provided as an array with field name \"result\". " + ], + }, + ] + + chat_session = model.start_chat( + history=history + ) + + +# This function only supports .jpg or .png files, +# and if the file is not a .jpg or .png file, it will return None +# Why rein +# TODO: make it async call +def upload2gemini(path): + # Using Pillow to check if the file is a .jpg or .png file + try: + with PIL.Image.open(path) as img: + if img.format not in ['JPEG', 'PNG']: + raise HTTPException(status_code=415, detail='This endpoint only support jpg or png file') + except PIL.UnidentifiedImageError: + return None + + size_of_file = os.path.getsize(path) + + headers = { + 'X-Goog-Upload-Protocol': 'resumable', + 'X-Goog-Upload-Command': 'start', + 'Content-Type': 'application/json', + 'X-Goog-Upload-Header-Content-Length': str(size_of_file), # header only support str, not int + } + + params = { + 'key': GOOGLE_API_KEY, + } + + filename = os.path.basename(path) + + data = "{'file': {'display_name': '" + filename + "'}}" + + response = requests.post( + f'{BASE_URL}/upload/v1beta/files', + params=params, + headers=headers, + data=data + ) + + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail="Error when upload image to gemini server. details: " + response.text) + + upload_url = response.headers["x-goog-upload-url"] + + headers = { + 'X-Goog-Upload-Offset': '0', + 'X-Goog-Upload-Command': 'upload, finalize', + 'Content-Type': 'application/x-www-form-urlencoded', + } + + with open(path, 'rb') as f: + data = f.read() + + res = requests.post(f'{upload_url}', headers=headers, data=data).json()['file'] + args = { + 'name': res['name'], + 'display_name': res['displayName'], + 'mime_type': res['mimeType'], + 'sha256_hash': res['sha256Hash'], + 'size_bytes': res['sizeBytes'], + 'state': res['state'], + 'uri': res['uri'], + 'create_time': res['createTime'], + 'expiration_time': res['expirationTime'], + 'update_time': res['updateTime'] + } + return genai.types.File(args) + +def detect_files(session,file,ingredients_sets): + response = session.send_message(file) + output = json.loads(response.text) + result = [] + + # ensure all the ingredients in output are on the set of ingredients + for i in output['result']: + if i in ingredients_sets: + result.append(i) + else: + print(f"Warning: {i} is not in the list of ingredients.") + + return result + +# TODO: depends on get_chat_session with auto fetching ingredients list +@detection_router.post("/gemini") +async def detect_by_gemini(files:List[UploadFile] = File(...)): + ingredients = ['asparagus', 'avocado', 'bamboo_shoots', 'beans_green', 'beetroot', 'cassava', 'chayote', 'cinnamon', + 'coriander', 'corn', 'egg', 'bean_mung', 'cabbage_napa', 'carrot', 'chicken', 'crab', 'garlic', + 'mint', 'pepper_bell', 'potato', 'chili', 'eggplant', 'gourd_bitter', 'gourd_bottle', + 'gourd_pointed', 'ham', 'jackfruit', 'lemon', 'mushroom_enoki', 'onion', 'pork', 'potato_sweet', + 'rice', 'almond', 'apple', 'artichoke', 'banana', 'blueberry', 'broccoli', 'broccoli_white', + 'mustard_greens', 'spinach', 'turnip', 'butter', 'cheese', 'milk', 'pasta', 'strawberry', + 'ash_gourd', 'beans_red', 'bokchoy', 'bread', 'brocolli_chinese', 'cabbage', 'cucumber', 'edamame', + 'fish', 'mushroom', 'noodle', 'okra', 'oyster', 'pumpkin', 'radish', 'seaweed', 'taro', 'tomato', + 'tomato_cherry', 'clam', 'burdock', 'peanut', 'spinach_water', 'leek', 'gourd_sponge', 'salmon', + 'apple_wax', 'chives', 'coconut', 'dragon_fruit', 'duck', 'durian', 'frog', 'ginger', 'grape', + 'guava', 'heim', 'kiwi', 'lettuce', 'mango', 'melon_water', 'orange', 'papaya', 'passion_fruit', + 'pineapple', 'potato_leaves', 'prawn', 'spinach_chinese', 'squid', 'tofu', 'zuccini', 'bean_green', + 'beef', 'melon_winter', 'lamb', 'lime', 'bean_sprout', 'tofu_dried', 'tofu_skin', 'ketchup', + 'truffle_sauce', 'miso', 'mayonnaise', 'scallop', 'oats', 'lotus_seed', 'goji', 'jujube', 'quinoa', + 'tomato_paste', 'tomato_can', 'sesame_sauce', 'century_egg', 'baby_corn', 'chili_bean_sauce', + 'basil', 'thyme', 'stokvis', 'sweet_bean_sauce', 'shallot', 'curry', 'yogurt', 'celery', 'stock', + 'sesame', 'soy_sauce', 'lobster', 'crabstick', 'tofu_puff', 'honey', 'yam', 'matcha', 'bean_soy', + 'kimchi', 'sugar_brown', 'egg_salted', 'bacon', 'cream_whip', 'tuna_can', 'paprika', + 'worcestershire_sauce', 'star_anise', 'tsaoko', 'clove', 'sichuan_pepper', 'lotus_root', + 'dried_shrimp', 'sesame_oil', 'mirin', 'sake', 'oyster_sauce', 'chinese_sauerkraut', 'chestnut', + 'shaoxing_wine', 'Chinese_spirits', 'bay_leaf', 'red_wine', 'konjac', 'fish_sauce', 'ginseng', + 'dried_clove_fish', 'bottle_gourd', 'dried_orange_peel', 'dry_beancurd_shreds', 'shacha_sauce', + 'pasta_sauce', 'rice_cake', 'flour', 'gochujang_sause', 'rice-wine', 'rosemary', 'bockwurst', + 'indian_buead', 'euryale_seed', 'coix_seed', 'chinese_angelica', 'longan', 'whisky', 'yeast', + 'sichuan_lovage_rhizome', 'radix_astragali', 'cmnamomi_mmulus', 'blood', 'nutmeg', 'dumpling_skin', + 'black_garlic', 'drinking_yogurt'] + session = await get_chat_session(ingredients) + + upload_coroutine = [] + + for file in files: + random_name = uuid.uuid4().hex + async with aiofiles.open(f"./img/{random_name}", 'wb') as output_file: + await output_file.write(await file.read()) + upload_coroutine.append(run_in_threadpool(upload2gemini, f"./img/{random_name}")) + + upload_files = await asyncio.gather(*upload_coroutine) + + response = await run_in_threadpool(detect_files,session,upload_files,set(ingredients)) + return response + + +# TODO: make it admin only. +@detection_router.post("/gemini/reset") +async def reset_chat_session(): + global chat_session + async with lock: + chat_session = None + return {'message': 'Reset success'} \ No newline at end of file From b16cca0fdb931a335ccf891620f353c31ceed806 Mon Sep 17 00:00:00 2001 From: bloodnighttw Date: Sun, 10 Nov 2024 01:13:17 +0800 Subject: [PATCH 2/6] Remove API KEY from code. --- src/detect.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/detect.py b/src/detect.py index 0c39069..0b298dd 100644 --- a/src/detect.py +++ b/src/detect.py @@ -392,8 +392,7 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil # The code to implement the # GOOGLE API KEY -# GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") -GOOGLE_API_KEY = "AIzaSyAxOCHULsHUHe7W7b1HXP_NGmq2KT_hozc" +GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") BASE_URL = "https://generativelanguage.googleapis.com" genai.configure(api_key=GOOGLE_API_KEY) From 1fc09c32aa769441dd43ad621a40c0b6235ada26 Mon Sep 17 00:00:00 2001 From: bloodnighttw Date: Sun, 10 Nov 2024 01:29:00 +0800 Subject: [PATCH 3/6] Remove API KEY from code and re-create key(then store them in env). --- src/detect.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/detect.py b/src/detect.py index 0b298dd..d41bd02 100644 --- a/src/detect.py +++ b/src/detect.py @@ -393,6 +393,7 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil # GOOGLE API KEY GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") +print(f"GOOGLE_API_KEY: {GOOGLE_API_KEY}") BASE_URL = "https://generativelanguage.googleapis.com" genai.configure(api_key=GOOGLE_API_KEY) @@ -454,7 +455,7 @@ def init_session(ingredients): # This function only supports .jpg or .png files, -# and if the file is not a .jpg or .png file, it will return None +# and if the file is not a .jpg or .png file, it will # Why rein # TODO: make it async call def upload2gemini(path): @@ -464,7 +465,7 @@ def upload2gemini(path): if img.format not in ['JPEG', 'PNG']: raise HTTPException(status_code=415, detail='This endpoint only support jpg or png file') except PIL.UnidentifiedImageError: - return None + raise HTTPException(status_code=415, detail='This endpoint is not the type that can be identified by Pillow') size_of_file = os.path.getsize(path) @@ -491,7 +492,7 @@ def upload2gemini(path): ) if response.status_code != 200: - raise HTTPException(status_code=response.status_code, detail="Error when upload image to gemini server. details: " + response.text) + raise HTTPException(status_code=response.status_code, detail="Error when upload image to gemini server. details: " + str(response.json())) upload_url = response.headers["x-goog-upload-url"] From 80789d215dc2f3d960011e624bc805347df8d512 Mon Sep 17 00:00:00 2001 From: bloodnighttw Date: Sun, 10 Nov 2024 01:52:03 +0800 Subject: [PATCH 4/6] reset operation in get_chat_session with condition. --- src/detect.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/detect.py b/src/detect.py index d41bd02..9c2897d 100644 --- a/src/detect.py +++ b/src/detect.py @@ -399,13 +399,17 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil genai.configure(api_key=GOOGLE_API_KEY) chat_session = None +list_of_i = [] lock = asyncio.Lock() # If ingredients change, server will need to restart to update the chat session async def get_chat_session(ingredients): async with lock: - global chat_session + global chat_session,list_of_i + if list_of_i != ingredients: + chat_session = None # reset chat session + list_of_i = ingredients if chat_session is None: await run_in_threadpool(init_session, ingredients) return copy.deepcopy(chat_session) @@ -536,7 +540,12 @@ def detect_files(session,file,ingredients_sets): # TODO: depends on get_chat_session with auto fetching ingredients list @detection_router.post("/gemini") -async def detect_by_gemini(files:List[UploadFile] = File(...)): +async def detect_by_gemini(files:List[UploadFile] = File(...)) -> List[str]: + """ + Detect the ingredients in the image by using the gemini API + + :param files: List of images + """ ingredients = ['asparagus', 'avocado', 'bamboo_shoots', 'beans_green', 'beetroot', 'cassava', 'chayote', 'cinnamon', 'coriander', 'corn', 'egg', 'bean_mung', 'cabbage_napa', 'carrot', 'chicken', 'crab', 'garlic', 'mint', 'pepper_bell', 'potato', 'chili', 'eggplant', 'gourd_bitter', 'gourd_bottle', @@ -576,13 +585,4 @@ async def detect_by_gemini(files:List[UploadFile] = File(...)): upload_files = await asyncio.gather(*upload_coroutine) response = await run_in_threadpool(detect_files,session,upload_files,set(ingredients)) - return response - - -# TODO: make it admin only. -@detection_router.post("/gemini/reset") -async def reset_chat_session(): - global chat_session - async with lock: - chat_session = None - return {'message': 'Reset success'} \ No newline at end of file + return response \ No newline at end of file From eca6084f3a7bbc29f2b2e33cc629d1634d120637 Mon Sep 17 00:00:00 2001 From: bloodnighttw Date: Sun, 10 Nov 2024 02:05:32 +0800 Subject: [PATCH 5/6] Stop showing api key in startup. --- src/detect.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/detect.py b/src/detect.py index 9c2897d..4232387 100644 --- a/src/detect.py +++ b/src/detect.py @@ -393,7 +393,6 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil # GOOGLE API KEY GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY") -print(f"GOOGLE_API_KEY: {GOOGLE_API_KEY}") BASE_URL = "https://generativelanguage.googleapis.com" genai.configure(api_key=GOOGLE_API_KEY) @@ -448,7 +447,7 @@ def init_session(ingredients): "role": "user", "parts": [ "Detect these shape of food which is on the list " + str( - ingredients) + "on these images down below, ,than map the detected ingredients to the name i provided as an array with field name \"result\". " + ingredients) + " on these images down below, ,than map the detected ingredients to the name i provided as an array with field name \"result\". " ], }, ] From e2c19022d1c22ee52fd060b27428815eab10bfd4 Mon Sep 17 00:00:00 2001 From: bloodnighttw Date: Sun, 10 Nov 2024 18:55:13 +0800 Subject: [PATCH 6/6] Update non-async code to async --- requirements.txt | 106 +++----------------- src/detect.py | 253 +++++++++++++++++++---------------------------- 2 files changed, 113 insertions(+), 246 deletions(-) diff --git a/requirements.txt b/requirements.txt index 8ad6b4f..2f5d0b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,94 +1,12 @@ -aiofiles==23.2.1 -aiosqlite==0.20.0 -alembic==1.13.1 -annotated-types==0.6.0 -anyio==4.3.0 -certifi==2024.2.2 -charset-normalizer==3.3.2 -click==8.1.7 -coloredlogs==15.0.1 -contourpy==1.2.1 -cycler==0.12.1 -fastapi==0.110.0 -filelock==3.14.0 -flatbuffers==24.3.25 -fonttools==4.51.0 -fsspec==2024.3.1 -gitdb==4.0.11 -GitPython==3.1.41 -greenlet==3.0.3 -h11==0.14.0 -humanfriendly==10.0 -idna==3.6 -Jinja2==3.1.4 -kiwisolver==1.4.5 -Mako==1.3.2 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -matplotlib==3.8.4 -mdurl==0.1.2 -mkdocs-git-committers-plugin-2==2.2.3 -mkdocs-git-revision-date-localized-plugin==1.2.4 -mkdocs-material==9.5.6 -mkdocs-material-extensions==1.3.1 -mpmath==1.3.0 -mypy==1.9.0 -mypy-extensions==1.0.0 -networkx==3.3 -numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==8.9.2.26 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.4.127 -nvidia-nvtx-cu12==12.1.105 -nvidia-tensorrt==99.0.0 -onnx==1.16.0 -onnxruntime-gpu==1.18.0 -onnxsim==0.4.36 -opencv-python==4.9.0.80 -packaging==24.0 -pandas==2.2.2 -pillow==10.3.0 -platformdirs==4.2.0 -protobuf==5.26.1 -psutil==5.9.8 -py-cpuinfo==9.0.0 -pydantic==2.6.4 -pydantic_core==2.16.3 -Pygments==2.18.0 -pyparsing==3.1.2 -python-dateutil==2.9.0.post0 -python-multipart==0.0.9 -pytz==2024.1 -PyYAML==6.0.1 -requests==2.31.0 -rich==13.7.1 -scipy==1.13.0 -seaborn==0.13.2 -six==1.16.0 -smmap==5.0.1 -sniffio==1.3.1 -SQLAlchemy==2.0.28 -starlette==0.36.3 -sympy==1.12 -tensorrt==10.0.1 -tensorrt-cu12==10.0.1 -tensorrt-cu12-bindings==10.0.1 -tensorrt-cu12-libs==10.0.1 -thop==0.1.1.post2209072238 -torch==2.3.0 -torchvision==0.18.0 -tqdm==4.66.4 -triton==2.3.0 -typing_extensions==4.10.0 -tzdata==2024.1 -ultralytics==8.2.14 -urllib3==2.2.1 -uvicorn==0.29.0 +fastapi~=0.103.0 +sqlalchemy~=2.0.25 +starlette~=0.27.0 +pydantic~=1.10.12 +pillow~=11.0.0 +aiofiles~=22.1.0 +opencv-python~=4.9.0.80 +numpy~=1.26.4 +aiohttp~=3.10.10 +sahi~=0.11.18 +ultralytics~=8.3.28 +alembic~=1.8.1 \ No newline at end of file diff --git a/src/detect.py b/src/detect.py index 4232387..84deff2 100644 --- a/src/detect.py +++ b/src/detect.py @@ -1,3 +1,5 @@ +import asyncio +import concurrent.futures import copy import io import json @@ -9,17 +11,19 @@ import PIL import aiofiles +import cv2 +import ffmpeg +import google.generativeai as genai import numpy as np -import requests +import aiohttp from PIL import Image, ImageDraw -from fastapi import APIRouter, UploadFile, File, Depends, HTTPException +from fastapi import APIRouter, UploadFile, File, Depends, HTTPException, BackgroundTasks +from google.ai.generativelanguage_v1beta.types import content +from sahi import AutoDetectionModel +from sahi.predict import get_sliced_prediction from sqlalchemy import select, update, delete from starlette.concurrency import run_in_threadpool from ultralytics import YOLO -from sahi import AutoDetectionModel -from sahi.predict import get_sliced_prediction -import asyncio -import concurrent.futures from sql_app.db import AsyncDBSession from sql_app.model.Model import Model @@ -27,12 +31,6 @@ from sql_app.model.User import User from user import token_verify -import cv2 -import ffmpeg - -import google.generativeai as genai -from google.ai.generativelanguage_v1beta.types import content - detection_router = APIRouter(prefix="/detect", tags=['detect']) limit = 0.2 # seconds @@ -41,17 +39,9 @@ detect_process_pool = concurrent.futures.ProcessPoolExecutor(max_workers=2) # each ingredient's confidence threshold -confidence_filter = { - "mushroom": 0.85, - "okra": 0.75, - "heim": 0.85, - "beef": 0.4, - "chicken": 0.4, - "pork": 0.4, - "noodle": 0.85, - "carrot": 0.5, - "common": 0.65 # the ingridient which is not in the filter -} +confidence_filter = {"mushroom": 0.85, "okra": 0.75, "heim": 0.85, "beef": 0.4, "chicken": 0.4, "pork": 0.4, + "noodle": 0.85, "carrot": 0.5, "common": 0.65 # the ingridient which is not in the filter + } def result_processing(img, results): # results:[{"key":[(x1,x2,y1,y2)]}] @@ -69,11 +59,7 @@ def result_processing(img, results): # results:[{"key":[(x1,x2,y1,y2)]}] def image_processing(image, model_path): - model = AutoDetectionModel.from_pretrained( - model_type="yolov8", - device="cuda:0", - model_path=model_path, - ) + model = AutoDetectionModel.from_pretrained(model_type="yolov8", device="cuda:0", model_path=model_path, ) results = get_sliced_prediction(image, model, slice_height=500, slice_width=500, overlap_height_ratio=0.2, overlap_width_ratio=0.2) result = {} @@ -165,8 +151,7 @@ async def detect_image(version: str, db: AsyncDBSession, image: UploadFile = Fil @detection_router.post("/upload/pt") -async def upload_pt(db: AsyncDBSession, description: str, version: str, - user: User = Depends(token_verify), +async def upload_pt(db: AsyncDBSession, description: str, version: str, user: User = Depends(token_verify), pt: UploadFile = File(...)): if user.level <= 127: raise HTTPException(status_code=401, detail='You are not administrator') @@ -197,12 +182,9 @@ async def upload_pt(db: AsyncDBSession, description: str, version: str, if result: path: str = result.file_path os.remove(path) - stmt = (update(Model).where(Model.version == version).values( - file_path=filename, - description=description, - size=size, - update_date=datetime.now() - )) + stmt = ( + update(Model).where(Model.version == version).values(file_path=filename, description=description, size=size, + update_date=datetime.now())) try: await db.execute(stmt) await db.commit() @@ -211,13 +193,8 @@ async def upload_pt(db: AsyncDBSession, description: str, version: str, raise e else: - model_information = Model( - file_path=filename, - description=description, - size=size, - version=version, - update_date=datetime.now() - ) + model_information = Model(file_path=filename, description=description, size=size, version=version, + update_date=datetime.now()) try: db.add(model_information) @@ -399,16 +376,23 @@ async def detect_video(version: str, db: AsyncDBSession, video: UploadFile = Fil chat_session = None list_of_i = [] +ingredients_set = set() lock = asyncio.Lock() # If ingredients change, server will need to restart to update the chat session -async def get_chat_session(ingredients): +async def get_chat_session(db:AsyncDBSession): + async with lock: - global chat_session,list_of_i + global chat_session, list_of_i, ingredients_set + + result = await db.execute(select(Ingredient.name)) + ingredients = [i.name for i in result] + if list_of_i != ingredients: - chat_session = None # reset chat session + chat_session = None # reset chat session list_of_i = ingredients + ingredients_set = set(ingredients) if chat_session is None: await run_in_threadpool(init_session, ingredients) return copy.deepcopy(chat_session) @@ -422,53 +406,37 @@ def init_session(ingredients): "top_p": 0.95, "top_k": 40, "max_output_tokens": 8192, - "response_schema": content.Schema( - type=content.Type.OBJECT, - properties={ - "result": content.Schema( - type=content.Type.ARRAY, - items=content.Schema( - type=content.Type.STRING, - ), - ), - }, - ), + "response_schema": content.Schema(type=content.Type.OBJECT, properties={ + "result": content.Schema(type=content.Type.ARRAY,items=content.Schema(type=content.Type.STRING)), + }), "response_mime_type": "application/json", } - model = genai.GenerativeModel( - model_name="gemini-1.5-pro", - generation_config=generation_config, - system_instruction="Detect food in image (but not text) and map to the name I provided.", - ) + model = genai.GenerativeModel(model_name="gemini-1.5-pro", generation_config=generation_config, + system_instruction="Detect food in image (but not text) and map to the name I provided.", ) - history = [ - { - "role": "user", - "parts": [ - "Detect these shape of food which is on the list " + str( - ingredients) + " on these images down below, ,than map the detected ingredients to the name i provided as an array with field name \"result\". " - ], - }, - ] + history = [{ + "role": "user", + "parts": [ + "Detect these shape of food which is on the list " + + str(ingredients) + " on these images down below," + + "than map the detected ingredients to the name i provided as an array with field name \"result\". " + ], + }] - chat_session = model.start_chat( - history=history - ) + chat_session = model.start_chat(history=history) # This function only supports .jpg or .png files, # and if the file is not a .jpg or .png file, it will -# Why rein -# TODO: make it async call -def upload2gemini(path): +async def upload2gemini(path): # Using Pillow to check if the file is a .jpg or .png file try: with PIL.Image.open(path) as img: if img.format not in ['JPEG', 'PNG']: raise HTTPException(status_code=415, detail='This endpoint only support jpg or png file') except PIL.UnidentifiedImageError: - raise HTTPException(status_code=415, detail='This endpoint is not the type that can be identified by Pillow') + raise HTTPException(status_code=415, detail='We can\'t recognize the file type') size_of_file = os.path.getsize(path) @@ -476,102 +444,80 @@ def upload2gemini(path): 'X-Goog-Upload-Protocol': 'resumable', 'X-Goog-Upload-Command': 'start', 'Content-Type': 'application/json', - 'X-Goog-Upload-Header-Content-Length': str(size_of_file), # header only support str, not int + 'X-Goog-Upload-Header-Content-Length': str(size_of_file) # header only support str, not int } - params = { - 'key': GOOGLE_API_KEY, - } + params = {'key': GOOGLE_API_KEY} filename = os.path.basename(path) data = "{'file': {'display_name': '" + filename + "'}}" - response = requests.post( - f'{BASE_URL}/upload/v1beta/files', - params=params, - headers=headers, - data=data - ) - - if response.status_code != 200: - raise HTTPException(status_code=response.status_code, detail="Error when upload image to gemini server. details: " + str(response.json())) - - upload_url = response.headers["x-goog-upload-url"] - - headers = { - 'X-Goog-Upload-Offset': '0', - 'X-Goog-Upload-Command': 'upload, finalize', - 'Content-Type': 'application/x-www-form-urlencoded', - } - - with open(path, 'rb') as f: - data = f.read() - - res = requests.post(f'{upload_url}', headers=headers, data=data).json()['file'] - args = { - 'name': res['name'], - 'display_name': res['displayName'], - 'mime_type': res['mimeType'], - 'sha256_hash': res['sha256Hash'], - 'size_bytes': res['sizeBytes'], - 'state': res['state'], - 'uri': res['uri'], - 'create_time': res['createTime'], - 'expiration_time': res['expirationTime'], - 'update_time': res['updateTime'] - } - return genai.types.File(args) - -def detect_files(session,file,ingredients_sets): + # response = requests.post(f'{BASE_URL}/upload/v1beta/files', params=params, headers=headers, data=data) + async with aiohttp.ClientSession() as session: + async with session.post(f'{BASE_URL}/upload/v1beta/files', params=params, headers=headers, data=data) as response: + if response.status != 200: + raise HTTPException(status_code=response.status_code, + detail="Error when upload image to gemini server. details: " + str(await response.json())) + + upload_url = response.headers["x-goog-upload-url"] + + headers = { + 'X-Goog-Upload-Offset': '0', + 'X-Goog-Upload-Command': 'upload, finalize', + 'Content-Type': 'application/x-www-form-urlencoded' + } + + async with aiofiles.open(path, 'rb') as file: + data = await file.read() + + # res = requests.post(f'{upload_url}', headers=headers, data=data).json()['file'] + async with session.post(f'{upload_url}', headers=headers, data=data) as response: + if response.status != 200: + raise HTTPException(status_code=response.status_code, + detail="Error when upload image to gemini server. details: " + str(await response.json())) + res = (await response.json())['file'] + + args = { + 'name': res['name'], + 'display_name': res['displayName'], + 'mime_type': res['mimeType'], + 'sha256_hash': res['sha256Hash'], + 'size_bytes': res['sizeBytes'], + 'state': res['state'], + 'uri': res['uri'], + 'create_time': res['createTime'], + 'expiration_time': res['expirationTime'], + 'update_time': res['updateTime'] + } + + return genai.types.File(args) + + +def detect_files(session, file): response = session.send_message(file) output = json.loads(response.text) result = [] # ensure all the ingredients in output are on the set of ingredients for i in output['result']: - if i in ingredients_sets: + if i in ingredients_set: result.append(i) else: print(f"Warning: {i} is not in the list of ingredients.") return result -# TODO: depends on get_chat_session with auto fetching ingredients list + + + @detection_router.post("/gemini") -async def detect_by_gemini(files:List[UploadFile] = File(...)) -> List[str]: +async def detect_by_gemini(background:BackgroundTasks,session = Depends(get_chat_session),files: List[UploadFile] = File(...)) -> List[str]: """ Detect the ingredients in the image by using the gemini API :param files: List of images """ - ingredients = ['asparagus', 'avocado', 'bamboo_shoots', 'beans_green', 'beetroot', 'cassava', 'chayote', 'cinnamon', - 'coriander', 'corn', 'egg', 'bean_mung', 'cabbage_napa', 'carrot', 'chicken', 'crab', 'garlic', - 'mint', 'pepper_bell', 'potato', 'chili', 'eggplant', 'gourd_bitter', 'gourd_bottle', - 'gourd_pointed', 'ham', 'jackfruit', 'lemon', 'mushroom_enoki', 'onion', 'pork', 'potato_sweet', - 'rice', 'almond', 'apple', 'artichoke', 'banana', 'blueberry', 'broccoli', 'broccoli_white', - 'mustard_greens', 'spinach', 'turnip', 'butter', 'cheese', 'milk', 'pasta', 'strawberry', - 'ash_gourd', 'beans_red', 'bokchoy', 'bread', 'brocolli_chinese', 'cabbage', 'cucumber', 'edamame', - 'fish', 'mushroom', 'noodle', 'okra', 'oyster', 'pumpkin', 'radish', 'seaweed', 'taro', 'tomato', - 'tomato_cherry', 'clam', 'burdock', 'peanut', 'spinach_water', 'leek', 'gourd_sponge', 'salmon', - 'apple_wax', 'chives', 'coconut', 'dragon_fruit', 'duck', 'durian', 'frog', 'ginger', 'grape', - 'guava', 'heim', 'kiwi', 'lettuce', 'mango', 'melon_water', 'orange', 'papaya', 'passion_fruit', - 'pineapple', 'potato_leaves', 'prawn', 'spinach_chinese', 'squid', 'tofu', 'zuccini', 'bean_green', - 'beef', 'melon_winter', 'lamb', 'lime', 'bean_sprout', 'tofu_dried', 'tofu_skin', 'ketchup', - 'truffle_sauce', 'miso', 'mayonnaise', 'scallop', 'oats', 'lotus_seed', 'goji', 'jujube', 'quinoa', - 'tomato_paste', 'tomato_can', 'sesame_sauce', 'century_egg', 'baby_corn', 'chili_bean_sauce', - 'basil', 'thyme', 'stokvis', 'sweet_bean_sauce', 'shallot', 'curry', 'yogurt', 'celery', 'stock', - 'sesame', 'soy_sauce', 'lobster', 'crabstick', 'tofu_puff', 'honey', 'yam', 'matcha', 'bean_soy', - 'kimchi', 'sugar_brown', 'egg_salted', 'bacon', 'cream_whip', 'tuna_can', 'paprika', - 'worcestershire_sauce', 'star_anise', 'tsaoko', 'clove', 'sichuan_pepper', 'lotus_root', - 'dried_shrimp', 'sesame_oil', 'mirin', 'sake', 'oyster_sauce', 'chinese_sauerkraut', 'chestnut', - 'shaoxing_wine', 'Chinese_spirits', 'bay_leaf', 'red_wine', 'konjac', 'fish_sauce', 'ginseng', - 'dried_clove_fish', 'bottle_gourd', 'dried_orange_peel', 'dry_beancurd_shreds', 'shacha_sauce', - 'pasta_sauce', 'rice_cake', 'flour', 'gochujang_sause', 'rice-wine', 'rosemary', 'bockwurst', - 'indian_buead', 'euryale_seed', 'coix_seed', 'chinese_angelica', 'longan', 'whisky', 'yeast', - 'sichuan_lovage_rhizome', 'radix_astragali', 'cmnamomi_mmulus', 'blood', 'nutmeg', 'dumpling_skin', - 'black_garlic', 'drinking_yogurt'] - session = await get_chat_session(ingredients) upload_coroutine = [] @@ -579,9 +525,12 @@ async def detect_by_gemini(files:List[UploadFile] = File(...)) -> List[str]: random_name = uuid.uuid4().hex async with aiofiles.open(f"./img/{random_name}", 'wb') as output_file: await output_file.write(await file.read()) - upload_coroutine.append(run_in_threadpool(upload2gemini, f"./img/{random_name}")) + upload_coroutine.append(upload2gemini(f"./img/{random_name}")) upload_files = await asyncio.gather(*upload_coroutine) - response = await run_in_threadpool(detect_files,session,upload_files,set(ingredients)) - return response \ No newline at end of file + response = await run_in_threadpool(detect_files, session, upload_files) + for i in upload_files: + background.add_task(run_in_threadpool,i.delete) + + return response