-
Notifications
You must be signed in to change notification settings - Fork 0
/
streamlit.py
114 lines (90 loc) · 4.06 KB
/
streamlit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import streamlit as st
import requests
import os
from dotenv import load_dotenv
# Load environment variables from a .env file so that the later
# os.getenv("API_URL", ...) lookup can pick up a configured service URL.
load_dotenv()
def classify_images(files):
    """Send images to the FastAPI ``/predict`` endpoint and return its predictions.

    The service base URL is read from the ``API_URL`` environment variable,
    defaulting to ``http://127.0.0.1:8000``.

    Args:
        files: Uploaded file objects (as returned by ``st.file_uploader``);
            each must expose ``name``, ``type`` and ``getvalue()``.

    Returns:
        The ``"predictions"`` value from the JSON response (``[]`` when the
        key is absent), or ``None`` when the request or response handling
        failed. In the ``None`` case the error has already been shown in the
        UI via ``st.error``.
    """
    # Build the URL with string operations: os.path.join is a filesystem
    # helper and would insert a backslash on Windows.
    base_url = os.getenv("API_URL", "http://127.0.0.1:8000")
    api_url = f"{base_url.rstrip('/')}/predict"

    # Send the raw bytes rather than the file object: getvalue() returns the
    # whole content regardless of the current stream position.
    file_data = [("files", (f.name, f.getvalue(), f.type)) for f in files]

    try:
        # 60s timeout to accommodate large uploads / slow inference.
        response = requests.post(api_url, files=file_data, timeout=60)
        response.raise_for_status()  # Treat 4xx/5xx responses as errors.
    except requests.exceptions.RequestException as e:
        st.error(f"Error connecting to the server: {e}")
        return None

    try:
        payload = response.json()
    except ValueError as e:
        # Covers json.JSONDecodeError and requests' JSONDecodeError subclass.
        st.error(f"Invalid response from the server: {e}")
        return None

    if not isinstance(payload, dict):
        st.error("Unexpected response format from the server.")
        return None
    return payload.get("predictions", [])
def display_predictions(predictions, uploaded_files):
    """Render one result row per prediction: the image beside its score and diagnosis.

    Args:
        predictions: List of result dicts from the service. Each one is read
            for the keys ``filename``, ``pneumothorax_prob`` and ``diagnosis``.
        uploaded_files: The files that were sent for classification. Results
            are paired with files by position; this presumes the service
            returns predictions in upload order (TODO confirm).
    """
    st.markdown("### Classification Results")

    if len(predictions) != len(uploaded_files):
        st.warning(
            f"Received {len(predictions)} predictions for "
            f"{len(uploaded_files)} uploaded images; results may be misaligned."
        )

    # zip() stops at the shorter sequence, so a count mismatch cannot raise
    # IndexError the way positional indexing into uploaded_files would.
    for result, image_file in zip(predictions, uploaded_files):
        filename = result['filename']
        # Coerce to float: st.progress interprets an int as a 0-100 percentage.
        pneumothorax_prob = float(result['pneumothorax_prob'])
        diagnosis = result['diagnosis']

        st.markdown("---")  # Separator between predictions
        col1, col2 = st.columns([1, 2])
        with col1:
            # NOTE(review): recent Streamlit releases deprecate use_column_width
            # in favour of use_container_width; switch once the pinned version allows.
            st.image(image_file, caption=filename, use_column_width=True)
        with col2:
            st.markdown(f"**Filename:** `{filename}`")
            st.markdown(f"**Pneumothorax Probability:** `{pneumothorax_prob:.2f}`")
            # Clamp into [0.0, 1.0]; st.progress rejects floats outside that range.
            st.progress(min(max(pneumothorax_prob, 0.0), 1.0))
            # Color-coded verdict
            if diagnosis == 'Pneumothorax':
                st.error("⚠️ **Positive for Pneumothorax**")
            else:
                st.success("✅ **No Pneumothorax Detected**")

    st.markdown("---")  # Final separator
def main():
    """Render the Pneumothorax Classification Service page.

    Lays out the page header, a sidebar uploader for chest X-ray images, a
    preview grid of the uploaded images and, when the classify button is
    pressed, the results returned by ``classify_images``.
    """
    # Page-level settings, applied before any other st.* call.
    st.set_page_config(page_title="Pneumothorax Classification Service", page_icon="🩺", layout="wide")

    # Title and introduction
    st.title("🩺 Pneumothorax Classification Service")
    st.markdown("""
    **Upload your chest X-ray images** below to classify whether the patient has Pneumothorax.
    This tool helps healthcare professionals analyze chest X-rays quickly.
    """)

    # Sidebar: usage instructions and the multi-file uploader
    with st.sidebar:
        st.header("Upload Images")
        st.markdown("""
        **How to use:**
        - Upload multiple chest X-ray images.
        - Click 'Classify Images' to see the predictions.
        """)
        uploaded_files = st.file_uploader(
            "Choose images (jpg, jpeg, png)",
            type=["jpg", "jpeg", "png"],
            accept_multiple_files=True,
        )

    if not uploaded_files:
        st.markdown("### No images uploaded yet.")
        st.info("Please upload images from the sidebar to start the classification.")
        return

    # Preview the uploads in a grid, filling the columns round-robin.
    st.markdown("### Uploaded Images")
    n_cols = 4
    cols = st.columns(n_cols)
    for idx, uploaded_file in enumerate(uploaded_files):
        with cols[idx % n_cols]:
            st.image(
                uploaded_file,
                caption=f"Image: {uploaded_file.name}",
                use_column_width=True,
            )

    if st.button("🚀 Classify Images"):
        with st.spinner("Classifying..."):
            predictions = classify_images(uploaded_files)
        if predictions:
            display_predictions(predictions, uploaded_files)
        elif predictions is not None:
            # None means the request failed and classify_images already showed
            # an st.error; only a successful-but-empty response is reported here.
            st.error("No predictions returned from the server.")
# Entry point: build the UI when this file is executed as a script.
if __name__ == "__main__":
    main()