-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathpredict.py
More file actions
45 lines (34 loc) · 1.11 KB
/
predict.py
File metadata and controls
45 lines (34 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from pathlib import Path
import numpy as np
from fastai.vision.all import load_learner, pd
# Dataset layout, relative to the current working directory.
PROJECT_DIR = Path("kitchenware_classifier")
DATA_DIR = PROJECT_DIR / "data"
IMG_DIR = DATA_DIR / "images"
TEST_FILE = DATA_DIR / "test.csv"

# Artifacts: the trained learner to load and the submission CSV to write.
MODEL_FILE = Path("fastai_model.pkl")
EXPORT_FILE = Path("submission.csv")

# The test table is read at import time. Each Id is left-padded with "0" to
# four characters to form its image file name (e.g. 7 -> ".../0007.jpg").
TEST_DF = pd.read_csv(TEST_FILE)
TEST_DF["image"] = TEST_DF["Id"].map(f"{IMG_DIR}/{{:0>4}}.jpg".format)
def process_images(df, model):
    """Run test-time-augmented inference over the images listed in ``df``.

    Args:
        df: DataFrame whose ``image`` column holds the image file paths.
        model: Path to a saved fastai learner, loaded with ``load_learner``.

    Returns:
        A ``(tta_result, dls)`` pair: the value returned by ``Learner.tta``
        on a test DataLoader built from ``df.image``, and the learner's
        ``DataLoaders`` (needed later for its ``vocab``).
    """
    # NOTE(review): load_learner unpickles the file; only load trusted models.
    learner = load_learner(model)
    loaders = learner.dls
    test_loader = loaders.test_dl(df.image)
    predictions = learner.tta(dl=test_loader)
    return predictions, loaders
def generate_submission(tta, dls, *, test_file=None, export_file=None):
    """Turn TTA predictions into a labelled submission CSV.

    Args:
        tta: The ``(preds, targets)`` pair returned by ``Learner.tta``. Only
            ``preds`` is used; it is reduced with ``argmax(dim=1)``, so it is
            assumed to be a tensor shaped (n_rows, n_classes).
        dls: Object exposing ``vocab``, the class names indexed by the
            predicted class index.
        test_file: CSV used as the submission skeleton (one row per
            prediction, same order). Defaults to the module's ``TEST_FILE``.
        export_file: Destination of the submission CSV. Defaults to the
            module's ``EXPORT_FILE``.

    Returns:
        The submission DataFrame: the columns of ``test_file`` plus a
        ``label`` column. It is also written to ``export_file`` without the
        index.

    Raises:
        ValueError: If the number of predictions differs from the number of
            rows in ``test_file``.
    """
    # Resolve defaults lazily so the module constants are only needed when
    # the caller does not supply explicit paths.
    test_file = TEST_FILE if test_file is None else test_file
    export_file = EXPORT_FILE if export_file is None else export_file

    tta_preds, _ = tta
    # Convert to a CPU numpy array so the vocab lookup also works when the
    # predictions live on a GPU.
    idxs = tta_preds.argmax(dim=1).cpu().numpy()
    vocab = np.array(dls.vocab)

    # Re-read the raw CSV rather than reusing the augmented test frame, so the
    # output keeps only the file's own columns (no derived "image" column).
    sub = pd.read_csv(test_file)
    if len(sub) != len(idxs):
        raise ValueError(
            f"Got {len(idxs)} predictions for {len(sub)} rows in {test_file}"
        )
    sub["label"] = vocab[idxs]
    sub.to_csv(export_file, index=False)
    return sub
def main():
    """Predict labels for the test set and write the submission file.

    If ``MODEL_FILE`` does not exist, prints a hint to run ``train.py`` and
    returns without doing anything else. Otherwise runs TTA inference on
    ``TEST_DF``, writes the submission, and prints its first rows.
    """
    if not MODEL_FILE.exists():
        print(f"Model {MODEL_FILE} not found!")
        print("Please run train.py first.")
        return

    predictions, loaders = process_images(TEST_DF, MODEL_FILE)
    submission = generate_submission(predictions, loaders)
    print(submission.head())
# Run the prediction pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()