-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path app.py
183 lines (146 loc) · 7.01 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
import pickle
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from dotenv import load_dotenv
import os
# Populate os.environ from a .env file via python-dotenv, then read the
# Gemini key; the value is None when the variable is not set anywhere.
load_dotenv()
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
class PlantClassifier:
    """Classify a plant-leaf image with a pickled PyTorch model.

    Args:
        model_path: Path to a pickle file holding the trained model object.
        class_labels_path: Text file with one class label per line; the label
            on line i is used for model output index i.
    """

    def __init__(self, model_path, class_labels_path):
        # NOTE(review): this overrides any CUDA_VISIBLE_DEVICES the caller set,
        # and CUDA only honours it if it is set before CUDA is initialised in
        # this process — confirm it still has the intended effect here.
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model(model_path)
        self.class_labels = self._load_class_labels(class_labels_path)
        self.transform = self._setup_transform()

    def _load_model(self, model_path):
        """Unpickle the model, move it to ``self.device`` and put it in eval mode.

        Raises:
            RuntimeError: if the file cannot be opened or unpickled, or the
                object rejects ``.to()`` / ``.eval()``. The underlying
                exception is chained as ``__cause__``.
        """
        # SECURITY: pickle.load can execute arbitrary code stored in the file;
        # only load model files from a trusted location.
        # NOTE(review): a model pickled while on GPU may fail to unpickle on a
        # CPU-only host; torch.save/torch.load(map_location=...) avoids that.
        try:
            with open(model_path, 'rb') as f:
                model = pickle.load(f)
            model.to(self.device)
            model.eval()
        except Exception as e:
            raise RuntimeError(f"Failed to load model: {str(e)}") from e
        return model

    def _load_class_labels(self, file_path):
        """Return the labels in *file_path*, one per line, whitespace-stripped.

        Blank lines are kept (as empty strings) so that line positions keep
        matching model output indices.

        Raises:
            RuntimeError: if the file cannot be read; the cause is chained.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as file:
                return [line.strip() for line in file]
        except Exception as e:
            raise RuntimeError(f"Failed to load class labels: {str(e)}") from e

    def _setup_transform(self):
        """Build the preprocessing pipeline: resize to 224x224, convert to a
        tensor, then normalise with the fixed per-channel mean/std below."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def predict(self, image):
        """Classify *image* and return ``(label, confidence_percent)``.

        Args:
            image: Anything ``PIL.Image.open`` accepts — a path or a binary
                file-like object such as an uploaded file.

        Returns:
            The top-1 class label and its softmax probability scaled to 0-100.

        Raises:
            RuntimeError: wrapping any failure in decoding, preprocessing or
                inference; the original exception is chained.
        """
        try:
            # The context manager releases the decoder's resources; a caller-
            # supplied file object is left open for further use.
            with Image.open(image) as img:
                rgb = img.convert('RGB')
            image_tensor = self.transform(rgb).unsqueeze(0).to(self.device)
            with torch.no_grad():
                outputs = self.model(image_tensor)
            # Accept models returning an object with ``.logits`` as well as
            # models returning the logits tensor directly.
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            probabilities = torch.nn.functional.softmax(logits, dim=1)
            confidence, predicted = torch.max(probabilities, 1)
            return self.class_labels[predicted.item()], confidence.item() * 100
        except Exception as e:
            raise RuntimeError(f"Prediction failed: {str(e)}") from e
class PlantInfoRetriever:
    """Answer free-text plant questions with a Gemini chat model via LangChain.

    Each call is independent: only the formatted prompt for that one question
    is sent, so no conversation history is carried between calls.

    Args:
        api_key: Google Generative AI key passed to ``ChatGoogleGenerativeAI``.
    """

    # Substituted into the prompt when the caller gives no context, so the
    # model is never pointed at an empty context section.
    _NO_CONTEXT = "No additional context provided."

    def __init__(self, api_key):
        self.chat_model = ChatGoogleGenerativeAI(
            api_key=api_key,
            model="gemini-pro",
            temperature=0.3,
        )
        # The original prompt asked the model to use "the context provided"
        # but had no slot for any context; {context} now supplies it.
        self.prompt_template = PromptTemplate(
            template=(
                "Please answer the given question. Make the response short and concise, about 150 words.\n"
                "If context is given, base the answer on it, and add relevant general knowledge "
                "only where it is consistent with that context.\n\n"
                "Context:\n{context}\n\n"
                "Question:\n{question}\n\n"
                "Answer:\n"
            ),
            input_variables=["context", "question"],
        )

    def get_information(self, query, context=None):
        """Return the model's answer to *query*, or an error message.

        Args:
            query: The question to answer.
            context: Optional background text (for example the identified
                plant) to ground the answer in. ``None`` or blank means none.

        Returns:
            The response text. Failures are not raised; they are returned as
            a readable message string so the caller can show it in the chat.
        """
        context_text = str(context).strip() if context is not None else ""
        if not context_text:
            context_text = self._NO_CONTEXT
        try:
            formatted_query = self.prompt_template.format(context=context_text, question=query)
            response = self.chat_model.invoke(formatted_query)
            return response.content
        except Exception as e:
            return f"An error occurred while fetching information: {str(e)}"
# Default on-disk locations of the trained model and its label list. Either
# can be overridden with PLANTPEDIA_MODEL_PATH / PLANTPEDIA_LABELS_PATH so the
# app is not tied to one machine's directory layout.
_DEFAULT_MODEL_PATH = '/fab3/btech/2021/manish.kumar21b/manish/PlantPedia/endsem/model/trained_model.pkl'
_DEFAULT_LABELS_PATH = "/fab3/btech/2021/manish.kumar21b/manish/PlantPedia/endsem/class_labels.txt"


def _init_session_state():
    """Seed every session-state key the page reads, keeping existing values."""
    # Built fresh on each call so separate sessions never share list objects.
    defaults = {
        'image_uploaded': False,
        'requests': [],
        'responses': [],
        'predicted_label': None,
        'confidence_score': 0.0,
        'initial_question_asked': False,
        'uploaded_image': None,
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value


@st.cache_resource
def _load_resources(model_path, class_labels_path, api_key):
    """Build the classifier and LLM client once and reuse them across reruns.

    Streamlit re-executes the script on every widget interaction; without this
    cache the pickled model would be re-read from disk each time. A call that
    raises stores nothing, so a failed load is retried on the next rerun.
    """
    classifier = PlantClassifier(
        model_path=model_path,
        class_labels_path=class_labels_path,
    )
    info_retriever = PlantInfoRetriever(api_key)
    return classifier, info_retriever


def _predict_and_introduce(uploaded_image, classifier, info_retriever):
    """Classify the upload, store the result, and on the first prediction of a
    chat ask the LLM for an introduction to the identified plant."""
    state = st.session_state
    state.uploaded_image = uploaded_image
    state.image_uploaded = True
    try:
        with st.spinner("Predicting..."):
            predicted_label, confidence_score = classifier.predict(uploaded_image)
        state.predicted_label = predicted_label
        state.confidence_score = confidence_score
        if not state.initial_question_asked:
            initial_question = f"What is a {predicted_label} plant leaf and can you provide some interesting facts about it?"
            with st.spinner("Generating answer..."):
                response = info_retriever.get_information(initial_question)
            state.requests.append(initial_question)
            state.responses.append(response)
            state.initial_question_asked = True
    except Exception as e:
        st.error(f"Error during prediction: {str(e)}")
        # Don't leave a label from an earlier image displayed next to this one.
        state.predicted_label = None
        state.confidence_score = 0.0


def _render_history():
    """Show the uploaded image, the prediction and all question/answer pairs."""
    state = st.session_state
    if state.image_uploaded and state.uploaded_image:
        st.image(state.uploaded_image, caption="Uploaded Image", use_column_width=True)
    if state.predicted_label:
        st.success(f"Predicted Label: {state.predicted_label} (Confidence: {state.confidence_score:.2f}%)")
    for request, response in zip(state.requests, state.responses):
        with st.container():
            st.markdown(f"**🙋 You:** {request}")
            st.markdown(f"**🌿 Chatbot:** {response}")
            st.markdown("---")


def _render_question_form(info_retriever):
    """Accept a follow-up question about the predicted plant and answer it."""
    label = st.session_state.predicted_label
    with st.form(key='question_form', clear_on_submit=True):
        user_question = st.text_input("Ask a question about the plant leaf:")
        submitted = st.form_submit_button("Submit")
    if submitted and user_question.strip():
        # get_information sends only this one question (no history, no image),
        # so name the predicted plant in the query; the chat log still shows
        # the user's own wording.
        query = f"Regarding the {label} plant leaf: {user_question}"
        with st.spinner("Generating answer..."):
            response = info_retriever.get_information(query)
        st.session_state.requests.append(user_question)
        st.session_state.responses.append(response)
        st.rerun()


def main():
    """Render the PlantPedia page: upload and classify a leaf, then chat about it."""
    # Keep this the first st.* call: older Streamlit releases reject
    # set_page_config after any other Streamlit command.
    st.set_page_config(page_title="PlantPedia 🌿", layout="wide")
    if not GEMINI_API_KEY:
        st.error("Gemini API Key not found. Please check your environment variables.")
        return
    _init_session_state()
    try:
        classifier, info_retriever = _load_resources(
            os.getenv("PLANTPEDIA_MODEL_PATH", _DEFAULT_MODEL_PATH),
            os.getenv("PLANTPEDIA_LABELS_PATH", _DEFAULT_LABELS_PATH),
            GEMINI_API_KEY,
        )
    except Exception as e:
        st.error(f"Failed to initialize application: {str(e)}")
        return
    st.title("PlantPedia 🌿🧠")
    if st.sidebar.button("New Chat"):
        # Snapshot the keys first: deleting entries while iterating a live
        # key view of the same mapping is unsafe.
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()
    st.subheader("Upload an image of a plant leaf and click 'Predict' to learn more about it.")
    uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
    if uploaded_image and st.button("Predict"):
        _predict_and_introduce(uploaded_image, classifier, info_retriever)
    _render_history()
    if st.session_state.predicted_label:
        _render_question_form(info_retriever)


if __name__ == "__main__":
    main()