-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrayserve_model.py
113 lines (96 loc) · 3.34 KB
/
rayserve_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import logging

import tensorflow as tf
from fastapi import FastAPI, HTTPException
from google.cloud import bigquery, storage
from ray import serve
# app version
__version__ = '0.1.2'
# Constants
BUCKET_NAME = 'europe-central2-rso-ml-airf-05c3abe0-bucket'  # GCS bucket holding model artifacts
MODEL_DIR = 'models'  # prefix inside the bucket where models live
PROJECT_ID = 'balmy-apogee-404909'  # GCP project for BigQuery
DATASET_ID = 'weather_prediction'  # BigQuery dataset
WEATHER_TABLE_ID = 'weather_history_LJ'  # weather-history table queried for inference input
# FastAPI app (served through Ray Serve's ingress below)
app = FastAPI()
# Function to get the best model name
def get_best_model_name():
    """Return the filename of the model currently flagged as best in GCS.

    Scans gs://BUCKET_NAME/MODEL_DIR/ for the first blob whose name ends
    in '.best' (a marker object); the corresponding model file shares the
    same stem with an '.h5' extension.

    Returns:
        str: the model filename, e.g. 'my_model.h5'.

    Raises:
        ValueError: if no '.best' marker blob is found.
    """
    storage_client = storage.Client()
    bucket = storage_client.bucket(BUCKET_NAME)
    # Iterate the blob pages lazily and stop at the first marker —
    # no need to materialize the whole listing into a list.
    blobs = bucket.list_blobs(prefix=MODEL_DIR + '/')
    best_model_file = next(
        (blob.name for blob in blobs if blob.name.endswith('.best')),
        None,
    )
    if best_model_file is None:
        raise ValueError("Best model file not found")
    # Marker 'models/foo.best' -> model file name 'foo.h5'.
    return best_model_file.split('/')[-1].replace('.best', '.h5')
def get_data():
    """Fetch the earliest weather row from BigQuery for inference.

    Returns:
        pandas.DataFrame | None: a single-row DataFrame indexed by 'time'
        with all columns cast to float32, or None on any failure.
        Returning None (rather than raising) is deliberate: callers treat
        it as "data unavailable" and turn it into an HTTP 500.
    """
    try:
        client = bigquery.Client(project=PROJECT_ID)
        query = f"""
        SELECT *
        FROM `{PROJECT_ID}.{DATASET_ID}.{WEATHER_TABLE_ID}`
        ORDER BY time
        LIMIT 1
        """
        query_job = client.query(query)
        df = query_job.to_dataframe()
        df.set_index('time', inplace=True)
        df = df.astype('float32')
        return df
    except Exception:
        # Best-effort contract preserved, but don't swallow the error
        # silently — record it so failures are diagnosable.
        logging.getLogger(__name__).exception("Failed to fetch weather data from BigQuery")
        return None
# Ray Serve deployment
# Ray Serve deployment
@serve.deployment
@serve.ingress(app)
class ModelPredictor:
    """Ray Serve deployment exposing the current best weather model via FastAPI.

    Loads the '.best'-marked Keras model from GCS at construction time and
    serves prediction, summary, and health endpoints.
    """

    def __init__(self):
        # self.model / self.model_name are populated by load_model();
        # initialized to None so the endpoints can detect a failed load.
        self.model = None
        self.model_name = None
        self.load_model()

    def load_model(self):
        """Resolve the best model name in GCS and load it with tf.keras."""
        best_model_name = get_best_model_name()
        self.model_name = best_model_name
        model_path = f"gs://{BUCKET_NAME}/{MODEL_DIR}/{best_model_name}"
        self.model = tf.keras.models.load_model(model_path)

    @app.get("/")
    async def root(self):
        """Greeting endpoint; also reveals which model file is in use."""
        if self.model is None:
            raise HTTPException(status_code=500, detail="Model is not loaded")
        return f"Hello, we are using model {self.model_name}!"

    @app.get("/predict")
    async def predict(self):
        """Run the model on the latest weather row and return both heads.

        The model returns two outputs (temperature and precipitation);
        each prediction batch has a single row, hence the [0] indexing.
        """
        # Check the model first so we don't pay for a BigQuery round-trip
        # when the service can't predict anyway.
        if self.model is None:
            raise HTTPException(status_code=500, detail="Model is not loaded")
        data = get_data()
        if data is None:
            raise HTTPException(status_code=500, detail="Data is not loaded")
        prediction = self.model.predict(data)
        temp_prediction, precip_prediction = prediction
        return {
            'temp_predict': temp_prediction.tolist()[0],
            'precip_predict': precip_prediction.tolist()[0]
        }

    @app.get("/summary")
    async def summary(self):
        """Return the Keras model summary as a list of text lines."""
        if self.model is None:
            raise HTTPException(status_code=500, detail="Model is not loaded")
        stringlist = []
        self.model.summary(print_fn=lambda x: stringlist.append(x))
        return stringlist

    @app.get("/health")
    async def health(self):
        """Liveness/readiness probe: 200 iff the model is loaded."""
        if self.model:
            return "Model is loaded"
        else:
            raise HTTPException(status_code=500, detail="Model is not loaded")
model_predictor = ModelPredictor.bind()