diff --git a/app.py b/app.py index ca578da..f24dbfe 100755 --- a/app.py +++ b/app.py @@ -2,9 +2,8 @@ from os import getenv from dotenv import load_dotenv -from autotagger import Autotagger +from autotagger import Autotagger, read_image from base64 import b64encode -from fastai.vision.core import PILImage from flask import Flask, request, render_template, jsonify, abort from werkzeug.exceptions import HTTPException import torch @@ -33,7 +32,7 @@ def evaluate(): output = request.values.get("format", "html") limit = int(request.values.get("limit", 50)) - images = [PILImage.create(file) for file in files] + images = [read_image(file) for file in files] predictions = autotagger.predict(images, threshold=threshold, limit=limit) if output == "html": diff --git a/autotag b/autotag index ab16621..8f4e46a 100755 --- a/autotag +++ b/autotag @@ -5,8 +5,7 @@ import click import itertools import logging import PIL -from fastai.vision.core import PILImage -from autotagger import Autotagger +from autotagger import Autotagger, read_image from pathlib import Path from more_itertools import ichunked @@ -68,7 +67,7 @@ def recurse_dir(directory): def open_image(filepath): try: with click.open_file(filepath, "rb") as file: - return (filepath, PILImage.create(file)) + return (filepath, read_image(file)) except PIL.UnidentifiedImageError as err: logging.warning(f"Skipped {filepath} (not an image)") return None diff --git a/autotagger/__init__.py b/autotagger/__init__.py index b5a3ef3..07f2650 100644 --- a/autotagger/__init__.py +++ b/autotagger/__init__.py @@ -1 +1 @@ -from .autotagger import Autotagger +from .autotagger import Autotagger, read_image diff --git a/autotagger/autotagger.py b/autotagger/autotagger.py index 01d7027..88dbd41 100644 --- a/autotagger/autotagger.py +++ b/autotagger/autotagger.py @@ -1,45 +1,112 @@ -from fastbook import * -from pandas import DataFrame, read_csv -from fastai.imports import noop -from fastai.callback.progress import ProgressCallback +from fastbook import create_timm_model +import pandas as pd +from pandas import DataFrame import timm import sys +import torch +from PIL import Image +import torchvision.transforms as transforms -class Autotagger: - def __init__(self, model_path="models/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"): - self.model_path = model_path - self.learn = self.init_model(data_path=data_path, tags_path=tags_path, model_path=model_path) - - def init_model(self, model_path="model/model.pth", data_path="test/tags.csv.gz", tags_path="data/tags.json"): - df = read_csv(data_path) - vocab = json.load(open(tags_path)) +# https://github.com/fastai/fastai/blob/176accfd5ae929d73d183d596c7155d3a9401f2f/fastai/vision/core.py#L96 +# load image and copy to new PIL Image object +# allows removal of fastai dep +def read_image(file): + im = Image.open(file) + im.load() + im = im._new(im.im) + return im - dblock = DataBlock( - blocks=(ImageBlock, MultiCategoryBlock(vocab=vocab)), - get_x = lambda df: Path("test") / df["filename"], - get_y = lambda df: df["tags"].split(" "), - item_tfms = Resize(224, method = ResizeMethod.Squish), - batch_tfms = [RandomErasing()] - ) - - dls = dblock.dataloaders(df) - learn = vision_learner(dls, "resnet152", pretrained=False) - model_file = open(model_path, "rb") - learn.load(model_file, with_opt=False) - learn.remove_cb(ProgressCallback) - learn.logger = noop +# take in a single string denoting file path, a single PIL Image instance, +# or a list of either or a combination and handle them using a map-style dataset +class InferenceDataset(torch.utils.data.Dataset): + def __init__(self, files, transform=None): + if isinstance(files, (list, tuple)): + self.files = files + else: + self.files = [files] + + self.transform = transform + + def __len__(self): + return len(self.files) + + def __getitem__(self, index): + image = self.files[index] + + # file path case + if isinstance(image, str): + image = Image.open(image) + + assert isinstance(image, Image.Image), "Dataset got invalid type, supported types: singular or list of the following: path as a string, PIL Image" + + # check if file valid + image.load() + + # fill transparent backgorunds with white and convert to RGB + image = image.convert("RGBA") + + # may not replicate behavior of old impl + color = (255,255,255) + background = Image.new('RGB', image.size, color) + background.paste(image, mask=image.split()[3]) + image = background + + if self.transform: image = self.transform(image) + + return image + +class Autotagger: + def __init__(self, model_path = "models/model.pth", tags_path="data/tags.json"): + + # load tags + self.classes = pd.read_json(tags_path) + + # instantiate fastai model + self.model,_ = create_timm_model("resnet152", len(self.classes), pretrained=False) - return learn + # load weights + self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) + # set to eval, script and optimize for inference (~2.5x speedup) + # trade off init time for faster inference, scripting/tracing is slow + self.model = self.model.eval() + + # depending on what models are used in the future, use either script or trace + # can't script due to fastai model defn, need to use trace + #self.model = torch.jit.script(self.model) + self.model = torch.jit.trace(self.model, torch.randn(1, 3, 224, 224)) + self.model = torch.jit.optimize_for_inference(self.model) + def predict(self, files, threshold=0.01, limit=50, bs=64): if not files: return - - dl = self.learn.dls.test_dl(files, bs=bs) - batch, _ = self.learn.get_preds(dl=dl) - - for scores in batch: - df = DataFrame({ "tag": self.learn.dls.vocab, "score": scores }) - df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit) - tags = dict(zip(df.tag, df.score)) - yield tags + + # instantiate dataset using files + dataset = InferenceDataset( + files, + transform=transforms.Compose([ + transforms.Resize((224,224)), + transforms.ToTensor(), + ]) + ) + + # create a dataloader, if calling predict with a large batch, + # the input is already split into bs chunks, may make more sense to + # call create a dl with bs of 1, may save memory/reduce latency + # depending on inputs and use case (autotag with 1 file vs list of files) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size = bs, + shuffle=False, + drop_last=False + ) + + for batch in dataloader: + preds = self.model(batch).sigmoid() + for scores in preds: + df = DataFrame({ "tag": self.classes[0], "score": scores }) + df = df[df.score >= threshold].sort_values("score", ascending=False).head(limit) + tags = dict(zip(df.tag, df.score)) + yield tags + + \ No newline at end of file