-
Notifications
You must be signed in to change notification settings - Fork 2
/
app.py
75 lines (62 loc) · 2.44 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from pathlib import Path
from shutil import rmtree
from time import sleep
import pandas as pd
import streamlit as st
from src.entity.config_entity import ModelPusherConfig, TrainingPipelineConfig
from src.pipeline.batch_prediction import start_batch_prediction
from src.pipeline.training_pipeline import start_training_pipeline
from src.predictor import ModelResolver
# Scratch copy of the user's uploaded CSV. Streamlit re-executes this whole
# script on every interaction, so drop any copy left over from a previous run;
# the upload form is the only place that writes it again.
input_fp = Path('input.csv')
try:
    # EAFP instead of exists()+unlink(): no race if the file vanishes between
    # the check and the delete.
    input_fp.unlink()
except FileNotFoundError:
    pass  # Nothing left over from a previous run -- that is the normal case.
st.title(':red[APS Fault Detection System]')

# A single reusable placeholder: each status message written to it replaces
# the previous one instead of stacking up on the page.
msg = st.empty()

# --- --- Variables --- --- #
# Most recent saved-model directory. A value of None presumably means no model
# has been trained yet (the "Train Model" button is offered in that case).
latest_dir_path = ModelResolver().get_latest_dir_path()

# Directory the model pusher writes trained models into.
saved_model_dir = Path(
    ModelPusherConfig(TrainingPipelineConfig()).saved_model_dir
)
# Train model button
if latest_dir_path is None:
    # No saved model was found: offer to run the training pipeline.
    with st.spinner('Training in progress...'):
        if st.button('Train Model', use_container_width=True):
            start_training_pipeline()
            msg.success('Model Training Completed.', icon='✅')
            # Give the user time to read the message before the rerun clears
            # it (same pause as the deletion branch below).
            sleep(2)
            st.experimental_rerun()
else:
    msg.warning('Model already trained. Start your Batch Prediction.',
                icon='🤖')
    with st.spinner('Deletion in progress...'):
        if st.button('Delete Pre-Trained Model', use_container_width=True):
            # Delete the saved-model tree and the logs folder. Either may
            # already be missing (e.g. removed by hand); skip it rather than
            # raise FileNotFoundError half-way through the reset.
            for dir_path in (saved_model_dir, Path('logs')):
                if dir_path.exists():
                    rmtree(dir_path)
            msg.success('Pre-Trained model deleted.', icon='✅')
            sleep(2)
            st.experimental_rerun()
# Upload CSV file button
with st.form("upload_form"):
    st.subheader(':green[Batch Prediction]')
    uploaded_file = st.file_uploader(
        "Choose a CSV file to upload", type=["csv"],
    )
    submitted = st.form_submit_button("Submit")
    if submitted and uploaded_file is None:
        # Previously a submit without a file silently did nothing.
        msg.warning('Please choose a CSV file before submitting.')
    elif submitted:
        # Batch prediction needs at least one trained model on disk.
        if (not saved_model_dir.exists()
                or not any(saved_model_dir.iterdir())):
            msg.warning('Model is not trained yet. Please train model.')
            st.stop()
        try:
            df = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError,
                UnicodeDecodeError) as err:
            # Report a bad upload to the user instead of crashing the app
            # with a traceback.
            msg.error(f'Could not read the uploaded CSV file: {err}')
            st.stop()
        # Persist the upload; the existence of input_fp is what triggers the
        # batch prediction further down in this script.
        df.to_csv(input_fp, index=False)
        msg.success('Your data is submitted successfully.')
# Download button
# input_fp only exists in the run where the upload form was just submitted;
# any stale copy is deleted at the top of the script on every rerun.
if input_fp.exists():
    # Prediction can take a while; show progress like the other long actions.
    with st.spinner('Prediction in progress...'):
        prediction_fp = start_batch_prediction(input_fp)
    df = pd.read_csv(prediction_fp)
    msg.success('Prediction Completed. Download your file.')
    if st.download_button(
        label="Download Prediction DataFrame",
        file_name="prediction.csv",
        data=df.to_csv(index=False),
        mime="text/csv",  # a registered MIME type; bare "csv" is not one
        use_container_width=True,
    ):
        # NOTE(review): clicking already triggers a rerun, and that rerun
        # deletes input_fp first, so this branch is normally not reached.
        st.experimental_rerun()