---
title: "Image Classification with CLIP"
description: "Deploy a CLIP model to classify images"
---

<Card
  title="View on Github"
  icon="github" href="https://github.com/basetenlabs/truss-examples-2/tree/main/2_image_classification/clip">
</Card>
In this example, we create a Truss that uses [CLIP](https://openai.com/research/clip) to classify images
against a set of pre-defined labels. The input to this Truss is the URL of an image; the output is a classification.

One important thing to note about this example is that since the inputs are images, we need
some mechanism for downloading them. To accomplish this, we have the user pass a downloadable URL to
the Truss, and the Truss code downloads the image. To do this efficiently, we make use of the
`preprocess` method in Truss.
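
Concretely, the Truss accepts a JSON body with a `url` key pointing at a downloadable image and returns nested lists of label probabilities. A quick sketch of a request and response (the probabilities are illustrative):

```python
# Illustrative request body: the "url" key is what the preprocess method pops off
request = {"url": "https://images.pexels.com/photos/1170986/pexels-photo-1170986.jpeg"}

# Illustrative response: softmax probabilities over the preset labels,
# ordered as ["a photo of a cat", "a photo of a dog"]
response = [[0.98, 0.02]]
```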
# Set up imports and constants

For our CLIP Truss, we will be using the Hugging Face transformers library, as well as
`pillow` for image processing.

```python model/model.py
import requests
from typing import Dict
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
```
This is the CLIP model from Hugging Face that we will use for this example.

```python model/model.py
CHECKPOINT = "openai/clip-vit-base-patch32"
```
# Define the Truss

In the `load` method, we load in the pretrained CLIP model from the
Hugging Face checkpoint specified above.

```python model/model.py
class Model:
    def __init__(self, **kwargs) -> None:
        self._processor = None
        self._model = None

    def load(self):
        """
        Loads the CLIP model and processor checkpoints.
        """
        self._model = CLIPModel.from_pretrained(CHECKPOINT)
        self._processor = CLIPProcessor.from_pretrained(CHECKPOINT)
```
In the `preprocess` method, we download the image from the URL and preprocess it.
This method is part of the Truss class and is designed to be used for any logic
involving IO, like downloading an image in this case.

It is called before the `predict` method in a separate thread, and it is not subject to the same
concurrency limits as `predict`, so it can be called many times in parallel.
This keeps `predict` from being unnecessarily blocked on IO-bound
tasks and helps improve the throughput of the Truss. See our [guide to concurrency](../guides/concurrency)
for more info.
||
```python model/model.py
    def preprocess(self, request: Dict) -> Dict:
        image = Image.open(requests.get(request.pop("url"), stream=True).raw)
        request["inputs"] = self._processor(
            text=["a photo of a cat", "a photo of a dog"],  # Define preset labels to use
            images=image,
            return_tensors="pt",
            padding=True
        )
        return request
```
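
`response.raw` works in the snippet above because `stream=True` is passed to `requests.get`. An equivalent, slightly more defensive way to fetch the image (a sketch, not part of the original example; the helper name is made up) is to surface HTTP errors explicitly and decode from an in-memory buffer:

```python
import io

import requests
from PIL import Image


def download_image(url: str) -> Image.Image:
    # Fetch the image bytes, fail loudly on HTTP errors, then let PIL decode from memory
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))
```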
The `predict` method performs the actual inference and outputs a probability associated
with each of the labels defined earlier.

```python model/model.py
    def predict(self, request: Dict) -> Dict:
        """
        This performs the actual classification. The predict method is subject to
        the predict concurrency constraints.
        """
        outputs = self._model(**request["inputs"])
        logits_per_image = outputs.logits_per_image
        return logits_per_image.softmax(dim=1).tolist()
```
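
The returned probabilities follow the order of the labels passed to the processor in `preprocess`, so a client can pair them back up. For example (the numbers are illustrative):

```python
labels = ["a photo of a cat", "a photo of a dog"]
probs = [[0.98, 0.02]]  # illustrative output from the model

# Map each label to its probability for readability
print(dict(zip(labels, probs[0])))
# {'a photo of a cat': 0.98, 'a photo of a dog': 0.02}
```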
# Set up the config.yaml

The main section that needs to be filled out to run CLIP is the `requirements` section,
where we need to include `transformers` for the model pipeline, `pillow` for image processing,
and `torch` as the underlying framework.

```yaml config.yaml
model_name: clip-example
requirements:
  - transformers==4.32.0
  - pillow==10.0.0
  - torch==2.0.1
model_metadata:
  example_model_input: {"url": "https://images.pexels.com/photos/1170986/pexels-photo-1170986.jpeg?auto=compress&cs=tinysrgb&w=1600"}
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: A10G
```
# Deploy the model

Deploy the CLIP model like you would other Trusses, with:

```bash
$ truss push
```

You can then invoke the model with:

```bash
$ truss predict -d '{"url": "https://source.unsplash.com/gKXKBY-C-Dk/300x300"}' --published
```
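
You can also call the deployed model from Python rather than the Truss CLI. The snippet below is a minimal sketch; the endpoint URL and header format are assumptions about the standard Baseten model endpoint, so substitute the model ID and API key shown in your model dashboard:

```python
import requests

# Hypothetical endpoint and credentials; replace with your own model ID and API key
resp = requests.post(
    "https://model-<model-id>.api.baseten.co/production/predict",
    headers={"Authorization": "Api-Key <your-api-key>"},
    json={"url": "https://source.unsplash.com/gKXKBY-C-Dk/300x300"},
)
print(resp.json())  # e.g. [[0.98, 0.02]] for the cat/dog labels
```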

<RequestExample>
```python model/model.py
import requests
from typing import Dict
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

CHECKPOINT = "openai/clip-vit-base-patch32"

class Model:
    def __init__(self, **kwargs) -> None:
        self._processor = None
        self._model = None

    def load(self):
        """
        Loads the CLIP model and processor checkpoints.
        """
        self._model = CLIPModel.from_pretrained(CHECKPOINT)
        self._processor = CLIPProcessor.from_pretrained(CHECKPOINT)

    def preprocess(self, request: Dict) -> Dict:
        image = Image.open(requests.get(request.pop("url"), stream=True).raw)
        request["inputs"] = self._processor(
            text=["a photo of a cat", "a photo of a dog"],  # Define preset labels to use
            images=image,
            return_tensors="pt",
            padding=True
        )
        return request

    def predict(self, request: Dict) -> Dict:
        """
        This performs the actual classification. The predict method is subject to
        the predict concurrency constraints.
        """
        outputs = self._model(**request["inputs"])
        logits_per_image = outputs.logits_per_image
        return logits_per_image.softmax(dim=1).tolist()
```
```yaml config.yaml
model_name: clip-example
requirements:
  - transformers==4.32.0
  - pillow==10.0.0
  - torch==2.0.1
model_metadata:
  example_model_input: {"url": "https://images.pexels.com/photos/1170986/pexels-photo-1170986.jpeg?auto=compress&cs=tinysrgb&w=1600"}
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: A10G
```
</RequestExample>