-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
75 lines (57 loc) · 2.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import requests
import streamlit as st
import pandas as pd
from src.settings import get_settings
from src.data_models import TimeSeries
from src.logic import calculate_ts_metrics
# Load application settings; ``settings.ml_service_url`` is read later in this
# script as the prediction endpoint.
settings = get_settings()

# Links shown in the Streamlit app menu.
_MENU_LINKS = {
    "About": "https://mvrck.space/",
    "Report a bug": "https://github.com/mvrck96/ml-service-frontend/issues/new",
}

# Configure the browser tab title, favicon and app menu. This is issued before
# any other Streamlit command in the script.
st.set_page_config(
    page_title="Time series prediction app",
    page_icon=":8ball:",
    menu_items=_MENU_LINKS,
)
# Seconds to wait for the ML backend. Without a timeout a stalled backend
# would block this script run (and the user's page) indefinitely.
REQUEST_TIMEOUT_SECONDS = 30


def _show_prediction(payload: str) -> None:
    """POST a serialized ``TimeSeries`` payload to the ML service and render the reply.

    Connection failures, timeouts, non-2xx statuses and malformed replies are
    reported in the UI instead of surfacing as a traceback.

    Args:
        payload: JSON text produced by ``TimeSeries(...).json()``.
    """
    try:
        response = requests.post(
            settings.ml_service_url,
            data=payload.encode("utf-8"),
            headers={"Content-Type": "application/json"},
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        # Turn 4xx/5xx into an exception; otherwise an error body would be
        # parsed as if it were a prediction.
        response.raise_for_status()
    except requests.exceptions.ConnectionError:
        # Also covers ConnectTimeout, which subclasses ConnectionError.
        st.warning("Can't connect to backend service !")
        return
    except requests.exceptions.Timeout:
        st.warning("Backend service did not respond in time !")
        return
    except requests.exceptions.HTTPError as err:
        st.error(f"Backend service returned an error: {err}")
        return

    try:
        result = response.json()
        pred = round(result["predicted_value"], 3)
        returned_feature = result["feature"]
    except (ValueError, KeyError, TypeError) as err:
        # ValueError: body is not JSON; KeyError: expected field missing;
        # TypeError: body is not an object or the value is not a number.
        st.error(f"Unexpected response from backend service: {err!r}")
        return

    label_col, value_col = st.columns(2)
    label_col.text("Feature name: ")
    value_col.markdown(f"`{returned_feature}`")

    label_col, value_col = st.columns(2)
    label_col.text("Predicted value: ")
    value_col.markdown(f"`{pred}`")


st.title("📈 Time series prediction UI")
csv_file = st.file_uploader("Upload csv file with time series data", type=["csv"])

# Nothing else is shown until a file has been uploaded.
if csv_file:
    # Parse the upload up front so a malformed file yields a readable message
    # instead of a traceback inside the form.
    try:
        df = pd.read_csv(csv_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the uploaded csv file: {err}")
        st.stop()

    # Only numeric columns can be summarised and predicted; offering text or
    # date columns would make the metric calculation fail after submit.
    numeric_columns = df.select_dtypes(include="number").columns
    if numeric_columns.empty:
        st.error("The uploaded csv file has no numeric columns to predict.")
        st.stop()

    with st.form("main_form"):
        st.header("Prediction parameters")
        feature_name = st.selectbox("Pick a feature to predict", options=numeric_columns)
        smoothing_level = st.slider(
            "Select smoothing level", min_value=0.0, max_value=1.0, step=0.05, value=0.5
        )
        submitted = st.form_submit_button("Calculate metrics and make prediction")

        # After button pressed
        if submitted:
            series = df[feature_name]

            st.subheader("Selected time series")
            st.line_chart(data=df, y=feature_name)

            st.subheader("Calculated metrics")
            metrics = calculate_ts_metrics(series.values)
            mean_col, std_col = st.columns(2)
            mean_col.metric(label="Mean", value=metrics["mean"])
            std_col.metric(label="Std", value=metrics["std"])

            st.markdown("---")
            st.subheader("Prediction result")

            # Assemble payload for POST request. ``tolist()`` yields native
            # Python numbers rather than numpy scalars, which serialise cleanly.
            payload = TimeSeries(
                feature=feature_name,
                data=series.tolist(),
                smoothing_level=smoothing_level,
            ).json()
            _show_prediction(payload)