Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Ignore config.py in the src directory
src/config.py
133 changes: 133 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
import time
import json
from datetime import datetime
from src.rag import send_chat_message

app = Flask(__name__)
CORS(app)

# In-memory "database" — all state is lost on process restart.
# conversations: maps conversation_id -> conversation record
#   ({title, engine, create_time, update_time, pairing}).
# conversation_id_list: insertion-ordered list of issued ids.
# new_conversation_id: next id to hand out (monotonically increasing counter).
conversations = {}
conversation_id_list = []
new_conversation_id = 0

# GET /conversations/id-list
@app.route('/conversations/id-list', methods=['GET'])
def get_id_list():
    """Return every conversation id currently held in memory."""
    payload = {"list": conversation_id_list}
    return jsonify(payload)

# GET /conversations/{conversation_id}
@app.route('/conversations/<int:conversation_id>', methods=['GET'])
def get_conversation(conversation_id):
    """Return the stored conversation for *conversation_id*, or 404."""
    found = conversations.get(conversation_id)
    # Truthiness check (not `is None`) mirrors the lookup contract:
    # a missing or empty record is treated as not found.
    if not found:
        return jsonify({"error": "Conversation not found"}), 404
    return jsonify(found)

# POST /conversations
# @app.route('/conversations', methods=['POST'])
# def create_conversation():
# data = request.get_json()

# # 요청으로부터 필요 데이터 추출
# conversation_id = 0
# conversation_id = data.get('data', {}).get('conversation_id', None)
# question = data.get('data', {}).get('question', None)
# engine = data.get('data', {}).get('engine', None)

# # error
# if (engine is None):
# return jsonify({"error": "Engine is not specified"}), 404

# # 기존 대화가 없으면 새로 생성
# if conversation_id is None:
# global new_conversation_id
# conversation_id = new_conversation_id
# new_conversation_id += 1
# conversation_id_list.append(conversation_id)
# conversations[conversation_id] = {
# "title": '${conversation_id}', # 일단 대화 id로 지정
# "engine": engine,
# "create_time": datetime.now(),
# "update_time": datetime.now(),
# "pairing": []
# }
# elif conversation_id not in conversations:
# return jsonify({"error": "Conversation not found"}), 404

# # 답변 생성 및 저장
# answer = "세종대왕은 한글을 창제하셨습니다." # 여기를 우리가 만든 모델에서 받아오게 추후 수정
# response_data = {
# 'data': {
# 'conversation_id': conversation_id,
# 'title': conversations[conversation_id]['title'],
# 'answer': answer
# }
# }

# # 요청 메시지와 응답 메시지 저장
# conversation_data = {
# "id": len(conversations[conversation_id]['pairing']),
# "request_message": question,
# "response_message": answer,
# "create_time": datetime.now()
# }
# conversations[conversation_id]['pairing'].append(conversation_data)

# return jsonify(response_data)

@app.route('/conversations', methods=['POST'])
def mvp_create_conversation():
    """Create or continue a conversation and return the model's answer.

    Expected JSON body::

        {"data": {"conversation_id": int | null, "question": str, "engine": str}}

    When ``conversation_id`` is absent a new conversation is created;
    otherwise the question/answer pair is appended to the existing one.
    Returns 404 when the engine is missing or the id is unknown.
    """
    data = request.get_json()

    # Extract the required fields from the request payload.
    # The default MUST be None (not 0): with a default of 0 the
    # "create new conversation" branch below was unreachable and the
    # very first request 404-ed on the unknown id 0.
    conversation_id = data.get('data', {}).get('conversation_id', None)
    question = data.get('data', {}).get('question', None)
    engine = data.get('data', {}).get('engine', None)

    # NOTE(review): 404 kept for client compatibility, though 400 would be
    # the conventional status for a missing request field.
    if engine is None:
        return jsonify({"error": "Engine is not specified"}), 404

    # No existing conversation: allocate a fresh id and record it.
    if conversation_id is None:
        global new_conversation_id
        conversation_id = new_conversation_id
        new_conversation_id += 1
        conversation_id_list.append(conversation_id)
        conversations[conversation_id] = {
            # Provisional title: the id itself. (Was the broken literal
            # '${conversation_id}' — JS template syntax, not Python.)
            "title": str(conversation_id),
            "engine": engine,
            "create_time": datetime.now(),
            "update_time": datetime.now(),
            "pairing": []
        }
    elif conversation_id not in conversations:
        return jsonify({"error": "Conversation not found"}), 404

    # Generate the answer via the RAG pipeline.
    answer = send_chat_message(question)
    response_data = {
        'data': {
            'conversation_id': conversation_id,
            # Return the stored title (previously hard-coded to "").
            'title': conversations[conversation_id]['title'],
            'answer': answer
        }
    }

    # Persist the question/answer pair and bump the update timestamp.
    conversation_data = {
        "id": len(conversations[conversation_id]['pairing']),
        "request_message": question,
        "response_message": answer,
        "create_time": datetime.now()
    }
    conversations[conversation_id]['pairing'].append(conversation_data)
    conversations[conversation_id]['update_time'] = datetime.now()

    return jsonify(response_data)

if __name__ == '__main__':
    # Development server only — debug=True must not be used in production.
    app.run(debug=True)
8 changes: 8 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
requests
google-auth
google-auth-oauthlib
google-cloud-aiplatform
langchain
flask
flask-cors
langchain_google_vertexai
Empty file added src/__init__.py
Empty file.
82 changes: 82 additions & 0 deletions src/rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import requests
# from vertexai.generative_models import (
# FunctionDeclaration,
# GenerationConfig,
# GenerativeModel,
# Part,
# Tool,
# )
from .vertex_ai_search import search_pdf

# def tool_declarate() -> Tool:
# # 항공 운임표 정보를 조회하는 기능
# get_flight_fare_info = FunctionDeclaration(
# name="get_flight_fare_info",
# description="Retrieve flight fare information from a PDF based on departure and destination",
# parameters={
# "type": "object",
# "properties": {
# "departure": {"type": "string", "description": "Departure city"},
# "destination": {"type": "string", "description": "Destination city"},
# },
# "required": ["departure", "destination"]
# },
# )


# # 여객 운송 약관(탑승 수속, 수하물 규정 등)을 조회하는 기능
# get_transport_policy_info = FunctionDeclaration(
# name="get_transport_policy_info",
# description="Retrieve transportation policy information such as check-in and baggage rules from a PDF",
# parameters={
# "type": "object",
# "properties": {
# "policy_type": {"type": "string", "description": "Type of policy, e.g., check-in, baggage rules"},
# },
# "required": ["policy_type"]
# },
# )

# # # 항공편 정보 조회하는 기능
# # get_flight_info = FunctionDeclaration(
# # name="get_transport_policy_info",
# # )

# flight_tool = Tool(
# function_declarations=[
# get_flight_fare_info,
# get_transport_policy_info,
# # get_flight_info,
# ],
# )
# return flight_tool

def send_chat_message(prompt: str) -> str:
    """Answer *prompt* via the Vertex AI Search RAG flow.

    Args:
        prompt: The user's question, passed verbatim to the search/RAG
            pipeline.

    Returns:
        The generated answer text from ``search_pdf``.
    """
    # The earlier Gemini function-calling prototype (tool_declarate /
    # GenerativeModel chat, previously commented out here) was removed:
    # search_pdf now performs both retrieval and answer generation.
    return search_pdf(prompt)


147 changes: 147 additions & 0 deletions src/vertex_ai_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
import vertexai
import google
import google.oauth2.credentials
from google.auth import compute_engine
import google.auth.transport.requests
import requests
import json
import os
from langchain_google_vertexai.llms import VertexAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
#from .config import PROJECT_ID, REGION, MODEL, SEARCH_URL
# Inlined configuration (normally loaded from the gitignored src/config.py,
# see the commented import above). PROJECT_ID and SEARCH_URL are
# intentionally blank and must be filled in before deployment.
PROJECT_ID = ""
REGION="asia-northeast3"
MODEL = "gemini-1.5-pro-001"
SEARCH_URL = ""

import google.auth
import google.auth.transport.requests
# NOTE(review): the access token is fetched once at import time; for a
# long-running server it will eventually expire — consider refreshing the
# credentials per request instead.
creds, project = google.auth.default()
auth_req = google.auth.transport.requests.Request()
creds.refresh(auth_req)
id_token = creds.token


def retrieve_vertex_ai_search(question: str, search_url: str, page_size: int) -> str:
    # stream = os.popen('gcloud auth print-access-token')
    # credential_token = stream.read().strip()

    """Query a Vertex AI Search (Discovery Engine) endpoint via REST.

    Args:
        question: Natural-language search query.
        search_url: Full REST endpoint URL of the search app.
        page_size: Maximum number of results to request.

    Returns:
        The raw JSON response body as text (parsed later by
        parse_discovery_results).
    """

    # Bearer token obtained once at module import time (see id_token above).
    headers = {
        "Authorization": "Bearer "+ id_token,
        "Content-Type": "application/json"
    }

    # NOTE(review): keys mix camelCase and snake_case; this matches what the
    # endpoint currently accepts — confirm against the Discovery Engine API
    # reference before changing.
    query_dic = {
        "query": question,
        "page_size": str(page_size),
        "offset": 0,
        "contentSearchSpec": {
            # "snippetSpec": {"maxSnippetCount": 5,
            # },
            # "summarySpec": { "summaryResultCount": 5,
            #         "includeCitations": True},
            "extractiveContentSpec": {
                #"maxExtractiveAnswerCount": 3,
                "maxExtractiveSegmentCount": 2,
                "num_previous_segments" : 1,
                "num_next_segments" : 1,
                "return_extractive_segment_score" : True
            }
        },
        # "queryExpansionSpec":{"condition":"AUTO"}
    }

    data = json.dumps(query_dic)

    # Encode data as UTF8
    data = data.encode("utf8")

    response = requests.post(search_url, headers=headers, data=data)

    # Debug output of the raw search response.
    print(response.text)
    return response.text

def parse_discovery_results(response_text: str) -> dict:
    """Build an LLM context dict from a Discovery Engine search response.

    Args:
        response_text: Raw JSON body returned by the search endpoint.

    Returns:
        Mapping of "Searched Context {i}" to a dict with keys
        ``answer_ctx``, ``segments_ctx`` and ``reference_link``; empty
        when the response contains no results.
    """
    payload = json.loads(response_text)
    contexts = {}

    def _strip_markup(text: str) -> str:
        # Drop the highlight tags / HTML entity the API embeds in content.
        return text.replace("<b>", "").replace("</b>", "").replace("&quot;", "")

    for index, hit in enumerate(payload.get('results') or []):
        doc_data = hit['document']['derivedStructData']

        # Concatenate all extractive answers / segments for this document.
        answers = "".join(
            a.get('content', '') for a in (doc_data.get('extractive_answers') or [])
        )
        segments = "".join(
            s.get('content', '') for s in (doc_data.get('extractive_segments') or [])
        )

        contexts[f"Searched Context {index}"] = {
            'answer_ctx': _strip_markup(answers),
            'segments_ctx': _strip_markup(segments),
            # Rewrite the gs:// URI into a browsable Cloud Storage link.
            'reference_link': doc_data['link'].replace(
                "gs://", "https://storage.cloud.google.com/"
            ),
        }

    return contexts

def search_pdf(question: str, SEARCH_URL: str = SEARCH_URL) -> str:
    """Answer an airline-CS question using RAG over indexed PDF documents.

    Retrieves supporting passages from Vertex AI Search, then asks a
    Gemini model (via langchain's VertexAI wrapper) for an answer
    grounded in that context.

    Args:
        question: The user's question.
        SEARCH_URL: Discovery Engine endpoint; defaults to (and shadows)
            the module-level constant of the same name.

    Returns:
        The model's generated answer text.
    """
    gemini_pro = VertexAI( model_name = MODEL,
                           project=PROJECT_ID,
                           location=REGION,
                           verbose=True,
                           streaming=False,
                           temperature = 0.2,
                           top_p = 1,
                           top_k = 40
                           )

    # Number of search results to retrieve for context.
    page_size = 5

    searched_ctx = retrieve_vertex_ai_search(question, SEARCH_URL, page_size)
    context = parse_discovery_results(searched_ctx)

    # Korean prompt: "You are an airline CS AI assistant. Answer the
    # Question strictly from the Context, reasoning step by step with
    # evidence." (runtime string — do not translate in code)
    prompt = PromptTemplate.from_template("""

    당신은 항공사 CS AI 어시스턴트입니다.
    아래 Question 에 대해서 반드시 Context에 있는 개별 내용을 기반으로 단계적으로 추론해서 근거를 설명하고 답변해주세요.
    Context : {context}
    Question : {question}

    """)

    prompt = prompt.format(context=context, question=question)

    # Debug output of the fully-rendered prompt.
    print(f"Prompt : {prompt}")

    response = gemini_pro.invoke(prompt)
    return response