# --- main.py ---------------------------------------------------------------
# Flask API exposing an in-memory conversation store backed by the RAG
# pipeline in src/rag.py. (Reconstructed from a whitespace-mangled patch;
# dead commented-out route and unused imports removed.)
from flask import Flask, request, jsonify
from flask_cors import CORS
from datetime import datetime

from src.rag import send_chat_message

app = Flask(__name__)
CORS(app)

# In-memory "database": {conversation_id: conversation dict}.
# NOTE(review): state is process-local and lost on restart — acceptable for
# an MVP, but not safe under a multi-worker WSGI deployment.
conversations = {}
conversation_id_list = []
new_conversation_id = 0


# GET /conversations/id-list — list all known conversation ids.
@app.route('/conversations/id-list', methods=['GET'])
def get_id_list():
    return jsonify({"list": conversation_id_list})


# GET /conversations/<conversation_id> — fetch one conversation by id.
# FIX: the route string lacked the <int:conversation_id> converter, so Flask
# could never bind the view's conversation_id argument.
@app.route('/conversations/<int:conversation_id>', methods=['GET'])
def get_conversation(conversation_id):
    conversation = conversations.get(conversation_id)
    if conversation:
        return jsonify(conversation)
    return jsonify({"error": "Conversation not found"}), 404


# POST /conversations — ask a question, creating a conversation if needed.
# Request body: {"data": {"conversation_id": int|null, "question": str,
#                         "engine": str}}
@app.route('/conversations', methods=['POST'])
def mvp_create_conversation():
    data = request.get_json()
    body = data.get('data', {})

    # FIX: conversation_id previously defaulted to 0 (not None), so a request
    # without an id skipped the "create new conversation" branch and 404'd.
    conversation_id = body.get('conversation_id')
    question = body.get('question')
    engine = body.get('engine')

    if engine is None:
        # TODO(review): 400 Bad Request would be semantically correct; 404 is
        # kept to preserve the existing API contract for current clients.
        return jsonify({"error": "Engine is not specified"}), 404

    # No id supplied: allocate a fresh conversation.
    if conversation_id is None:
        global new_conversation_id
        conversation_id = new_conversation_id
        new_conversation_id += 1
        conversation_id_list.append(conversation_id)
        conversations[conversation_id] = {
            # FIX: was the literal string '${conversation_id}' (JS template
            # syntax left in a Python file); use the id as provisional title.
            "title": str(conversation_id),
            "engine": engine,
            "create_time": datetime.now(),
            "update_time": datetime.now(),
            "pairing": [],
        }
    elif conversation_id not in conversations:
        return jsonify({"error": "Conversation not found"}), 404

    # Generate the answer via the RAG pipeline.
    answer = send_chat_message(question)
    response_data = {
        'data': {
            'conversation_id': conversation_id,
            'title': "",
            'answer': answer,
        }
    }

    # Record the request/response pair on the conversation.
    conversations[conversation_id]['pairing'].append({
        "id": len(conversations[conversation_id]['pairing']),
        "request_message": question,
        "response_message": answer,
        "create_time": datetime.now(),
    })
    # FIX: update_time was set at creation but never refreshed on new messages.
    conversations[conversation_id]['update_time'] = datetime.now()

    return jsonify(response_data)


if __name__ == '__main__':
    app.run(debug=True)
# --- src/rag.py ------------------------------------------------------------
from .vertex_ai_search import search_pdf


def send_chat_message(prompt: str) -> str:
    """Answer *prompt* via retrieval over the Vertex AI Search datastore.

    The commented-out Gemini function-calling scaffolding (tool_declarate,
    chat.send_message, function_call extraction) was dead code and has been
    removed; the MVP delegates straight to search_pdf().
    """
    return search_pdf(prompt)


# --- src/vertex_ai_search.py -----------------------------------------------
# Unused imports removed (vertexai, os, LLMChain, compute_engine,
# google.oauth2.credentials).
import json

import requests
import google.auth
import google.auth.transport.requests
from langchain_google_vertexai.llms import VertexAI
from langchain.prompts import PromptTemplate

# TODO(review): these belong in the gitignored src/config.py (see the
# commented-out `from .config import ...`); fill in before deployment.
PROJECT_ID = ""
REGION = "asia-northeast3"
MODEL = "gemini-1.5-pro-001"
SEARCH_URL = ""


def _get_access_token() -> str:
    """Return a fresh OAuth2 access token for the application default creds.

    FIX: the token was previously minted once at import time, so any request
    made after the token expired (~1 hour) failed with 401. Refreshing per
    call keeps it valid for long-running servers.
    """
    creds, _project = google.auth.default()
    creds.refresh(google.auth.transport.requests.Request())
    return creds.token


def retrieve_vertex_ai_search(question: str, search_url: str, page_size: int) -> str:
    """POST *question* to a Vertex AI Search (Discovery Engine) endpoint.

    Args:
        question: natural-language query.
        search_url: full REST search endpoint URL.
        page_size: number of results to request.

    Returns:
        The raw JSON response body as text (parsed by
        parse_discovery_results).
    """
    headers = {
        "Authorization": "Bearer " + _get_access_token(),
        "Content-Type": "application/json",
    }

    # Extractive segments only; the snippet/summary specs that were commented
    # out have been dropped as dead code.
    query_dic = {
        "query": question,
        "page_size": str(page_size),
        "offset": 0,
        "contentSearchSpec": {
            "extractiveContentSpec": {
                "maxExtractiveSegmentCount": 2,
                "num_previous_segments": 1,
                "num_next_segments": 1,
                "return_extractive_segment_score": True,
            }
        },
    }

    data = json.dumps(query_dic).encode("utf8")
    response = requests.post(search_url, headers=headers, data=data)

    print(response.text)
    return response.text


def parse_discovery_results(response_text: str) -> dict:
    """Build an LLM context dict from a Discovery Engine response.

    Returns {"Searched Context N": {"answer_ctx", "segments_ctx",
    "reference_link"}}. FIX: previously returned None implicitly when the
    response had no 'results'; now always returns a dict.
    """
    dict_results = json.loads(response_text)
    searched_ctx_dic = {}

    for result_index, result in enumerate(dict_results.get('results', [])):
        derived = result['document']['derivedStructData']
        reference = derived['link']

        # Concatenate extractive answers / segments when present.
        answer_ctx = "".join(
            a.get('content', '') for a in derived.get('extractive_answers', []) or []
        )
        segments_ctx = "".join(
            s.get('content', '') for s in derived.get('extractive_segments', []) or []
        )

        # FIX(reconstructed): strip the <b>/</b> highlight tags and &quot;
        # entities Vertex AI Search embeds in extractive content. The patch
        # text was mangled here (empty .replace("", "") calls and a stray
        # triple quote from an eaten &quot;) — confirm against the live API.
        for token in ("<b>", "</b>", "&quot;"):
            answer_ctx = answer_ctx.replace(token, "")
            segments_ctx = segments_ctx.replace(token, "")

        # Turn the gs:// URI into a clickable Cloud Storage console link.
        reference_link = reference.replace("gs://", "https://storage.cloud.google.com/")

        searched_ctx_dic[f"Searched Context {result_index}"] = {
            'answer_ctx': answer_ctx,
            'segments_ctx': segments_ctx,
            'reference_link': reference_link,
        }

    return searched_ctx_dic


def search_pdf(question: str, SEARCH_URL: str = SEARCH_URL) -> str:
    """Answer an airline-CS question with RAG over the search datastore.

    NOTE(review): the second parameter keeps its original non-PEP8 name
    (SEARCH_URL, shadowing the module constant) so keyword callers keep
    working; rename in a coordinated change.
    """
    gemini_pro = VertexAI(
        model_name=MODEL,
        project=PROJECT_ID,
        location=REGION,
        verbose=True,
        streaming=False,
        temperature=0.2,
        top_p=1,
        top_k=40,
    )

    page_size = 5

    searched_ctx = retrieve_vertex_ai_search(question, SEARCH_URL, page_size)
    context = parse_discovery_results(searched_ctx)

    # Korean system prompt preserved verbatim (runtime string).
    prompt = PromptTemplate.from_template("""

    당신은 항공사 CS AI 어시스턴트입니다.
    아래 Question 에 대해서 반드시 Context에 있는 개별 내용을 기반으로 단계적으로 추론해서 근거를 설명하고 답변해주세요.
    Context : {context}
    Question : {question}

    """)

    prompt = prompt.format(context=context, question=question)
    print(f"Prompt : {prompt}")

    return gemini_pro.invoke(prompt)