-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAstroMind.py
More file actions
209 lines (159 loc) · 6.08 KB
/
AstroMind.py
File metadata and controls
209 lines (159 loc) · 6.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
import streamlit as st
from typing import List
from dotenv import load_dotenv
from src.constants import *
from src.vdb import VectorDB
from src.vdb_qdrant import QdrantVectorDB
from src.llm import LLM
from src.llm_openai import OpenAILLM
from src.embedder import BAAIEmbedder
from src.embedding_pipeline import EmbeddingPipeline
from src.ships.ships_embedding_pipeline import ShipsEmbeddingPipeline
#=== Ask ===#
@st.cache_resource
def run_pipelines(_pipelines):
    """Run every embedding pipeline once, reporting failures in the UI.

    Cached by Streamlit so the (potentially expensive) embedding work only
    happens on the first call per session.  The leading underscore on
    ``_pipelines`` tells Streamlit not to hash the argument for the cache key.

    Args:
        _pipelines: iterable of EmbeddingPipeline instances to start.
    """
    for pipeline in _pipelines:
        result = pipeline.start()
        if result == EmbeddingPipeline.FAILURE:
            # BUG FIX: the original used an f-string with no placeholders;
            # name the failing pipeline so multiple failures are distinguishable.
            st.error(f"Failed to embed the dataset ({type(pipeline).__name__})")
def chat_ui(llm: LLM, vdb: VectorDB, pipelines: List[EmbeddingPipeline]):
    """Render the Astro Mind chat interface.

    Shows the welcome text, replays the conversation history, displays
    example questions on first load, answers user queries with RAG
    (vector-DB context + LLM), and exposes debug controls in the sidebar.

    Args:
        llm: inference LLM used for both topic extraction and answering.
        vdb: vector database searched for ship context.
        pipelines: embedding pipelines, re-runnable from the debug sidebar.
    """
    st.title("Astro Mind")
    st.write("""
Welcome to Astro Mind, your friendly AI assistant for Elite Dangerous commanders!
I'm here to help you explore and understand the capabilities of ships across the galaxy.
Whether you're a seasoned explorer or new to the game, I can provide detailed information about ships,
their specifications, and more.
To get started, you can ask me questions like:
- "What are the specs of the Anaconda?"
- "How does the Python compare to the Cobra Mk IV?"
- "What are the key features of the Asp Explorer?"
Let's dive into the galaxy together!
""")

    # Replay prior turns so the conversation survives Streamlit reruns.
    if 'conversation_history' not in st.session_state:
        st.session_state.conversation_history = []
    for message in st.session_state.conversation_history:
        with st.chat_message("user" if message.startswith("User:") else "assistant"):
            st.markdown(message.replace("User: ", "").replace("Assistant: ", ""))

    # Show example prompts on first load
    if 'show_examples' not in st.session_state:
        st.session_state.show_examples = True
    if st.session_state.show_examples:
        st.header("Example Questions")
        st.write("""
Here are some examples of questions you can ask me:
- "What are the specs of the Anaconda?"
- "How does the Python compare to the Cobra Mk IV?"
- "What are the key features of the Asp Explorer?"
- "Tell me about the best exploration ships"
- "What's the top choice for a combat vessel?"
You can ask about any ship in Elite Dangerous, and I'll provide detailed information!
""")
        if st.button("Got it, let's start!"):
            st.session_state.show_examples = False
            st.rerun()

    user_query = st.chat_input("Hello commander, how can I help you today?")
    if user_query:
        with st.chat_message("user"):
            st.markdown(user_query)
        st.session_state.conversation_history.append(f"User: {user_query}")

        # Extract the ship topic once per session, then reuse it so follow-up
        # questions stay anchored to the same ship.
        topic = ""
        if 'topic' not in st.session_state:
            topic = llm.ask(
                user_query,
                """What ship are they talking about?
Reply only with the name of the ship using the following format:
The query relates to the ship: <shipname>.
Replace <shipname> by the actual shipname the user are talking about.
"""
            )
            st.session_state.topic = topic
        else:
            topic = st.session_state.topic

        # Retrieve ship context from the vector DB and answer with it.
        context = vdb.search(f"{user_query}\n{topic}", SHIPS_COLLECTION_NAME)
        context.append(topic)
        response = llm.ask(context, user_query)
        st.session_state.conversation_history.append(f"Assistant: {response}")
        with st.chat_message("assistant"):
            st.markdown(response)

    with st.sidebar:
        st.header("Debug", divider=True)
        # BUG FIX: run_pipelines requires the pipeline list; without `args`
        # the on_click callback was invoked with no arguments and raised a
        # TypeError when the button was pressed.
        st.button(
            "Embed documents",
            on_click=run_pipelines,
            args=(pipelines,),
            key="embed_button"
        )
        # Add reset button
        if st.button("⚠️ Reset All Resources", key="reset_button"):
            # Clear cached resources
            st.cache_resource.clear()
            # Close vector database connection
            if hasattr(vdb, 'close'):
                vdb.close()
            # Clear session state
            st.session_state.clear()
            # Rerun the app to reinitialize resources
            st.rerun()
def ui(pipelines: List[EmbeddingPipeline], llm: LLM, vdb: VectorDB):
    """Top-level UI entry point; delegates rendering to the chat interface."""
    chat_ui(llm=llm, vdb=vdb, pipelines=pipelines)
#=== Setup ===#
def setup_environment():
    """Load .env configuration once per session and export the HF token.

    Idempotent: a session-state flag prevents re-running on Streamlit reruns.
    """
    if "env_setup" in st.session_state:
        return
    load_dotenv()
    # BUG FIX: os.environ values must be strings; os.getenv returns None when
    # the variable is missing, which made the original assignment raise
    # "TypeError: str expected, not NoneType". Guard and warn instead.
    token = os.getenv(EMBEDDER_LLM_API_KEY)
    if token is not None:
        os.environ["HF_TOKEN"] = token
    else:
        st.warning("Embedder API key is not set; HF_TOKEN was not exported.")
    st.session_state["env_setup"] = True
@st.cache_resource
def setup_inference_llm():
    """Build and cache the OpenAI-compatible inference LLM from environment config.

    Reads provider, API key, model, and endpoint URL from the environment and
    installs the RAG user-prompt template used for answering queries.
    """
    provider = os.getenv(INFERENCE_LLM_PROVIDER)
    api_key = os.getenv(INFERENCE_LLM_API_KEY)
    model = os.getenv(INFERENCE_LLM_MODEL)
    url = os.getenv(INFERENCE_LLM_URL)

    llm = OpenAILLM(provider=provider, api_key=api_key, model=model, url=url)
    llm.user_prompt = """
Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.
<context>
{context}
</context>
<question>
{query}
</question>
"""
    return llm
@st.cache_resource
def setup_llm():
    """Return the cached inference LLM (indirection point for future model routing)."""
    return setup_inference_llm()
@st.cache_resource
def setup_vector_db():
    """Create and cache the Qdrant-backed vector DB with a BAAI embedder."""
    embedder = BAAIEmbedder()
    db: VectorDB = QdrantVectorDB(embedder=embedder, db_path=LOCAL_VECTOR_DB_FILE)
    return db
#=== Main ===#
def main():
    """App entry point: set up resources, embed the dataset, launch the UI."""
    #--- Setup ---#
    setup_environment()
    try:
        vdb = setup_vector_db()
        llm = setup_llm()
    except RuntimeError as e:
        # Cached resources can conflict across reruns; offer a manual reset.
        st.error(f"Resource conflict error: {str(e)}")
        st.error("Please reset the resources using the 'Reset All Resources' button.")
        # Show only the reset button
        if st.button("⚠️ Reset All Resources"):
            st.cache_resource.clear()
            st.session_state.clear()
            st.rerun()
        return

    #--- Embedding ---#
    pipelines = [ShipsEmbeddingPipeline(vdb)]
    run_pipelines(pipelines)

    #--- Ask ---#
    ui(pipelines, llm, vdb)


if __name__ == "__main__":
    main()