-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathinference_api.py
61 lines (49 loc) · 1.79 KB
/
inference_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
# BaseModel from Pydantic is used to define data objects.
from pydantic import BaseModel, Field
from starter.train_model import online_inference
# FastAPI application instance; the route decorators further down this module
# (`@app.get`, `@app.post`) register their handlers on it.
app = FastAPI()
# Names of the categorical features, forwarded unchanged to `online_inference`
# by the /inference handler.
# NOTE(review): two entries are hyphenated ("marital-status",
# "native-country") while the request model spells them with underscores --
# presumably `online_inference` reconciles the two; confirm before renaming.
CAT_FEATURES = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]
class RowData(BaseModel):
    # Request body for POST /inference: a single record to classify.
    #
    # Every field is required (the `...` default) and carries an `example`
    # value as sample data. Declaration order is left exactly as it was,
    # because the encoded dict handed to `online_inference` follows it.
    #
    # NOTE(review): attribute names use underscores, whereas CAT_FEATURES
    # lists hyphenated names ("marital-status", "native-country") --
    # presumably `online_inference` maps between them; confirm.
    #
    # Documentation is kept in comments rather than a class docstring so the
    # generated schema description is not altered.
    age: int = Field(..., example=32)
    workclass: str = Field(..., example="Private")
    fnlwgt: int = Field(..., example=205019)

    education: str = Field(..., example="Assoc-acdm")
    education_num: int = Field(..., example=12)

    marital_status: str = Field(..., example="Never-married")
    occupation: str = Field(..., example="Sales")
    relationship: str = Field(..., example="Not-in-family")
    race: str = Field(..., example="Black")
    sex: str = Field(..., example="Male")

    capital_gain: int = Field(..., example=0)
    capital_loss: int = Field(..., example=0)
    hours_per_week: int = Field(..., example=50)

    native_country: str = Field(..., example="United-States")
# Deployment bootstrap, executed at import time.
#
# Runs only when the DYNO environment variable is set (presumably the Heroku
# dyno marker -- confirm) and a `.dvc` directory exists in the working
# directory. It then configures DVC, pulls the tracked artifacts and removes
# the DVC tooling afterwards. Every command is kept under this guard so that
# a local checkout never has its `.dvc` directory deleted.
if "DYNO" in os.environ and os.path.isdir(".dvc"):
    # Let DVC operate without a source-control repository.
    os.system("dvc config core.no_scm true")
    # Best-effort cleanup of a previous cache and lock file before pulling;
    # the exit status of these commands is intentionally not checked.
    os.system('rm -rf .dvc/cache')
    os.system('rm -rf .dvc/tmp/lock')
    os.system('dvc config core.hardlink_lock true')
    if os.system("dvc pull -q") != 0:
        # Abort start-up if the pull fails. `raise SystemExit` is used
        # instead of `exit()`: the latter is a helper injected by the `site`
        # module for interactive use and is not guaranteed to exist (for
        # example under `python -S`). Message and exit status are unchanged.
        raise SystemExit("dvc pull failed")
    # The artifacts are in place; remove the DVC metadata and binaries.
    os.system("rm -rf .dvc .apt/usr/lib/dvc")
@app.get("/")
def home():
    # Root endpoint: answers GET / with a fixed greeting payload.
    # (Kept as comments rather than a docstring so the generated API
    # description for this route stays unchanged.)
    greeting = {"Hello": "Welcome to project 3!"}
    return greeting
@app.post('/inference')
async def predict_income(inputrow: RowData):
    # Inference endpoint: encodes the validated request body into plain
    # JSON-compatible data, passes it to `online_inference` together with the
    # model path and the categorical feature names, and wraps the result
    # under the "income class" key.
    #
    # NOTE(review): `online_inference` is called synchronously (no await)
    # inside this coroutine and receives the model path on every request --
    # presumably it loads the pickle each time; confirm whether that blocks
    # the event loop for long enough to matter.
    features = jsonable_encoder(inputrow)
    income_class = online_inference(
        features,
        'model/random_forest_model_with_encoder_and_lb.pkl',
        CAT_FEATURES,
    )
    return {"income class": income_class}