-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
58 lines (49 loc) · 1.65 KB
/
main.py
File metadata and controls
58 lines (49 loc) · 1.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import numpy as np
from PIL import Image
import gradio as gr
import mlflow
import mlflow.keras
from dotenv import load_dotenv
# Pull settings from a local .env file (if present) into the environment
# before any of them are read below.
load_dotenv()

# Address of the MLflow tracking server. Honors the MLFLOW_TRACKING_URI
# environment variable and falls back to a local server when it is unset.
MLFLOW_TRACKING_URI = os.getenv("MLFLOW_TRACKING_URI", "http://localhost:5000")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

# MLflow URI of the Keras model to serve; None when the variable is unset.
ARTIFACT_URI = os.getenv("ARTIFACT_URI")

# `model` stays None on any failure so the module still imports and the UI
# can start; classify_cifar checks for None and reports an error instead.
model = None
if not ARTIFACT_URI:
    # Fail with a clear message rather than handing None to MLflow.
    print("ERROR: Failed to load model.\nARTIFACT_URI environment variable is not set.")
else:
    print(f"Attempting to load model from MLflow URI: {ARTIFACT_URI}")
    try:
        model = mlflow.keras.load_model(ARTIFACT_URI)
    except Exception as e:  # deliberate best-effort: keep the app startable
        print(f"ERROR: Failed to load model.\n{e}")
    else:
        print("Model loaded successfully.")
# Names of the ten CIFAR-10 categories. Order matters: the entry at index i
# is used as the label for element i of the model's prediction vector.
CIFAR10_CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
def classify_cifar(image):
    """Classify an image with the loaded CIFAR-10 model.

    Args:
        image: The picture as a NumPy array (the Gradio image input in this
            file is configured with type="numpy"), or None when no image
            was supplied.

    Returns:
        A dict mapping each CIFAR-10 class name to the model's predicted
        score for it. Returns {"Error": 1.0} when the model could not be
        loaded, and an empty dict when no image was supplied.
    """
    if model is None:
        return {"Error": 1.0}
    if image is None:
        # Nothing to classify (e.g. the input was submitted empty);
        # an empty result avoids crashing on None.astype below.
        return {}

    # Force three channels so a grayscale or RGBA array still yields the
    # 32x32x3 input built here, then shrink to the 32x32 size fed to the model.
    pixels = np.asarray(image).astype("uint8")
    img = Image.fromarray(pixels).convert("RGB").resize((32, 32))

    # Scale to [0, 1] and add a leading batch axis for model.predict.
    batch = np.asarray(img, dtype="float32") / 255.0
    batch = np.expand_dims(batch, axis=0)

    preds = model.predict(batch, verbose=0)[0]
    # Pair labels with outputs positionally instead of assuming exactly ten.
    return {label: float(score) for label, score in zip(CIFAR10_CLASSES, preds)}
# Image input: accepts a file upload or a webcam capture and passes the
# callback an RGB NumPy array.
_image_input = gr.Image(
    label="Upload an Image",
    type="numpy",
    sources=["upload", "webcam"],
    image_mode="RGB",
    height=256,
)

# Output widget limited to the three highest-scoring classes.
_label_output = gr.Label(num_top_classes=3)

# Sample images offered in the UI, one single-input row per example.
_example_rows = [
    ["https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"],
    ["https://upload.wikimedia.org/wikipedia/commons/thumb/5/50/Voyager_of_the_Seas_Costa_Maya_2023.jpg/2560px-Voyager_of_the_Seas_Costa_Maya_2023.jpg"],
]

# Web UI wiring the image input to classify_cifar and its label output.
iface = gr.Interface(
    fn=classify_cifar,
    inputs=_image_input,
    outputs=_label_output,
    title="🧠 CIFAR-10 Image Classifier",
    description="Upload or capture an image to classify it.",
    examples=_example_rows,
)

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()