@@ -1,73 +1,29 @@
-import logging
 import os
 import pickle
-from datetime import datetime
-from typing import Optional

 import chromadb
 import gradio as gr
 import logfire
 from custom_retriever import CustomRetriever
 from dotenv import load_dotenv
 from llama_index.agent.openai import OpenAIAgent
-from llama_index.core import VectorStoreIndex, get_response_synthesizer
-from llama_index.core.agent import AgentRunner, ReActAgent
-
-# from llama_index.core.chat_engine import (
-#     CondensePlusContextChatEngine,
-#     CondenseQuestionChatEngine,
-#     ContextChatEngine,
-# )
-from llama_index.core.data_structs import Node
+from llama_index.core import VectorStoreIndex
 from llama_index.core.llms import MessageRole
 from llama_index.core.memory import ChatMemoryBuffer
 from llama_index.core.node_parser import SentenceSplitter
-from llama_index.core.query_engine import RetrieverQueryEngine
 from llama_index.core.retrievers import VectorIndexRetriever
-from llama_index.core.tools import (
-    FunctionTool,
-    QueryEngineTool,
-    RetrieverTool,
-    ToolMetadata,
-)
-
-# from llama_index.core.vector_stores import (
-#     ExactMatchFilter,
-#     FilterCondition,
-#     FilterOperator,
-#     MetadataFilter,
-#     MetadataFilters,
-# )
+from llama_index.core.tools import RetrieverTool, ToolMetadata
 from llama_index.embeddings.openai import OpenAIEmbedding
-from llama_index.llms.gemini import Gemini
 from llama_index.llms.openai import OpenAI
-from llama_index.llms.openai.utils import GPT4_MODELS
 from llama_index.vector_stores.chroma import ChromaVectorStore
-from tutor_prompts import (
-    TEXT_QA_TEMPLATE,
-    QueryValidation,
-    system_message_openai_agent,
-    system_message_validation,
-    system_prompt,
-)
-
-load_dotenv()
-
+from tutor_prompts import system_message_openai_agent

 # from utils import init_mongo_db

-logging.getLogger("gradio").setLevel(logging.INFO)
-logging.getLogger("httpx").setLevel(logging.WARNING)
+load_dotenv()
+
 logfire.configure()
-# logging.basicConfig(handlers=[logfire.LogfireLoggingHandler("INFO")])
-# logger = logging.getLogger(__name__)

-# # This variables are used to intercept API calls
-# # launch mitmweb
-# cert_file = "/Users/omar/Documents/mitmproxy-ca-cert.pem"
-# os.environ["REQUESTS_CA_BUNDLE"] = cert_file
-# os.environ["SSL_CERT_FILE"] = cert_file
-# os.environ["HTTPS_PROXY"] = "http://127.0.0.1:8080"

 CONCURRENCY_COUNT = int(os.getenv("CONCURRENCY_COUNT", 64))
 MONGODB_URI = os.getenv("MONGODB_URI")
@@ -131,7 +87,6 @@
     use_async=True,
 )
 vector_retriever = VectorIndexRetriever(
-    # filters=filters,
     index=index,
     similarity_top_k=10,
     use_async=True,
@@ -204,12 +159,10 @@ def generate_completion(
     chat_list = memory.get()

     if len(chat_list) != 0:
-        # Compute number of interactions
         user_index = [
             i for i, msg in enumerate(chat_list) if msg.role == MessageRole.USER
         ]
         if len(user_index) > len(history):
-            # A message was removed, need to update the memory
            user_index_to_remove = user_index[len(history)]
             chat_list = chat_list[:user_index_to_remove]
             memory.set(chat_list)
@@ -237,40 +190,9 @@ def generate_completion(
     # )
     # custom_retriever = CustomRetriever(vector_retriever, document_dict)

-    if model == "gemini-1.5-flash" or model == "gemini-1.5-pro":
-        llm = Gemini(
-            api_key=os.getenv("GOOGLE_API_KEY"),
-            model=f"models/{model}",
-            temperature=1,
-            max_tokens=None,
-        )
-    else:
-        llm = OpenAI(temperature=1, model=model, max_tokens=None)
-    client = llm._get_client()
-    logfire.instrument_openai(client)
-
-    # response_synthesizer = get_response_synthesizer(
-    #     llm=llm,
-    #     response_mode="simple_summarize",
-    #     text_qa_template=TEXT_QA_TEMPLATE,
-    #     streaming=True,
-    # )
-
-    # custom_query_engine = RetrieverQueryEngine(
-    #     retriever=custom_retriever,
-    #     response_synthesizer=response_synthesizer,
-    # )
-
-    # agent = CondensePlusContextChatEngine.from_defaults(
-    # agent = CondenseQuestionChatEngine.from_defaults(
-
-    # agent = ContextChatEngine.from_defaults(
-    #     retriever=custom_retriever,
-    #     context_template=system_prompt,
-    #     llm=llm,
-    #     memory=memory,
-    #     verbose=True,
-    # )
+    llm = OpenAI(temperature=1, model=model, max_tokens=None)
+    client = llm._get_client()
+    logfire.instrument_openai(client)

     query_engine_tools = [
         RetrieverTool(
@@ -282,23 +204,13 @@ def generate_completion(
         )
     ]

-    if model == "gemini-1.5-flash" or model == "gemini-1.5-pro":
-        agent = AgentRunner.from_llm(
-            llm=llm,
-            tools=query_engine_tools,  # type: ignore
-            verbose=True,
-            memory=memory,
-            # system_prompt=system_message_openai_agent,
-        )
-    else:
-        agent = OpenAIAgent.from_tools(
-            llm=llm,
-            memory=memory,
-            tools=query_engine_tools,  # type: ignore
-            system_prompt=system_message_openai_agent,
-        )
+    agent = OpenAIAgent.from_tools(
+        llm=llm,
+        memory=memory,
+        tools=query_engine_tools,  # type: ignore
+        system_prompt=system_message_openai_agent,
+    )

-    # completion = custom_query_engine.query(query)
     completion = agent.stream_chat(query)

     answer_str = ""
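
For reference, a minimal sketch of how the streamed completion above is typically consumed, assuming llama_index's `stream_chat` returns a streaming response that exposes a `response_gen` token generator; the loop below is illustrative and not part of this commit:

    answer_str = ""
    for token in completion.response_gen:
        # Accumulate streamed tokens and yield partial answers
        # (e.g. for a Gradio streaming chat callback).
        answer_str += token
        yield answer_str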