[docs] Add ViLT model code usage example #1179

Open · wants to merge 1 commit into base: gh/ryan-qiyu-jiang/42/base
100 changes: 100 additions & 0 deletions website/docs/projects/vilt.md
To pretrain a ViLT model from scratch on the COCO dataset,
```
mmf_run config=projects/vilt/configs/masked_coco/pretrain.yaml run_type=train_val dataset=masked_coco model=vilt
```

## Using the ViLT model from code
Here is an example of running the ViLT model from code to do visual question answering (VQA) on a raw image and text.
The forward pass takes ~15 ms, which is fast compared to UNITER's ~600 ms.

```python
from argparse import Namespace

import torch
from mmf.common.sample import SampleList
from mmf.datasets.processors.bert_processors import VILTTextTokenizer
from mmf.datasets.processors.image_processors import VILTImageProcessor
from mmf.utils.build import build_model
from mmf.utils.configuration import Configuration, load_yaml
from mmf.utils.general import get_current_device
from mmf.utils.text import VocabDict
from omegaconf import OmegaConf
from PIL import Image
```

First, build the model config and instantiate the ViLT model.
```python
# make model config for vilt vqa2
model_name = "vilt"
config_args = Namespace(
    config_override=None,
    opts=["model=vilt", "dataset=vqa2", "config=configs/defaults.yaml"],
)
default_config = Configuration(config_args).get_config()
model_vqa_config = load_yaml(
    "/private/home/your/path/to/mmf/projects/vilt/configs/vqa2/defaults.yaml"
)
config = OmegaConf.merge(default_config, model_vqa_config)
OmegaConf.resolve(config)
model_config = config.model_config[model_name]
model_config.model = model_name
vilt_model = build_model(model_config)
```
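
As an optional sanity check (not part of the original example), you can confirm the model was built by counting its parameters:
```python
# optional sanity check: rough parameter count of the instantiated model
num_params = sum(p.numel() for p in vilt_model.parameters())
print(f"ViLT parameters: {num_params / 1e6:.1f}M")
```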

Load the model weights. `model_checkpoint_path` points to the checkpoint downloaded from the model zoo entry `vilt.vqa`, whose current URL is `s3://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz`.
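
If you have not fetched the archive yet, here is a rough sketch of one way to download and unpack it; the HTTPS mirror of the s3 path and the archive layout are assumptions, so adjust the paths to match what you actually get.
```python
# sketch only: fetch and unpack the checkpoint archive (HTTPS mirror path assumed)
import tarfile
import urllib.request

archive_url = "https://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz"
archive_path = "./vilt.finetuned.vqa2.tar.gz"
urllib.request.urlretrieve(archive_url, archive_path)

with tarfile.open(archive_path, "r:gz") as tar:
    tar.extractall("./vilt_checkpoint")  # the extracted .pth file name may differ from ./vilt_vqa2.pth
```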
```python
# build model and load weights
model_checkpoint_path = './vilt_vqa2.pth'
state_dict = torch.load(model_checkpoint_path)
vilt_model.load_state_dict(state_dict, strict=False)
vilt_model.eval()
vilt_model = vilt_model.to(get_current_device())
```

Prepare the image and text inputs.
This example uses an image of a man with a hat kissing his daughter;
the text is the question posed to the ViLT model for visual question answering.
```python
# get image input
image_processor = VILTImageProcessor({"size": [384, 384]})
image_path = "./kissing_image.jpg"
raw_img = Image.open(image_path).convert("RGB")
image = image_processor(raw_img)

# get text input
text_tokenizer = VILTTextTokenizer({})
question = "What is on his head?"
processed_text_dict = text_tokenizer({"text": question})
```
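
To sanity-check the preprocessing, you can print the tensor shapes. The exact keys in `processed_text_dict` depend on the tokenizer output and are not spelled out in the original example, so the loop below just inspects whatever tensors are present.
```python
# quick shape check of the processed inputs
print("image:", tuple(image.shape))  # expect (3, 384, 384) given the processor config above
for key, value in processed_text_dict.items():
    if isinstance(value, torch.Tensor):
        print(key, tuple(value.shape))
```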

Wrap everything up in a `SampleList`, as expected by the ViLT `BaseModel`.
```python
# make batch inputs
sample_dict = {**processed_text_dict, "image": image}
sample_dict = {
    k: v.unsqueeze(0) for k, v in sample_dict.items() if isinstance(v, torch.Tensor)
}
sample_dict["targets"] = torch.zeros((1, 3129))  # 3129 answer classes in the VQA2 answer vocab
sample_dict["targets"][0, 1358] = 1  # dummy one-hot target; the prediction below does not depend on it
sample_dict["dataset_name"] = "vqa2"
sample_dict["dataset_type"] = "test"
sample_list = SampleList(sample_dict).to(get_current_device())
```

Load the VQA answer-id → answer-string vocabulary so the prediction can be read as text.
The vocabulary file is currently hosted at `s3://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt`.
```python
# load vqa2 id -> answers
vocab_file_path = "/private/home/path/to/answers_vqa.txt"
answer_vocab = VocabDict(vocab_file_path)
```
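
You can also go the other way and look up an answer's index, for example to build a real one-hot target instead of the dummy one above; `word2idx` is assumed here to mirror the `idx2word` call used below.
```python
# look up the index of an answer string (assumed word2idx counterpart of idx2word)
hat_id = answer_vocab.word2idx("hat")
print(hat_id)
```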

And here's the part you've been waiting for!
```python
# do prediction
with torch.no_grad():
    vqa_logits = vilt_model(sample_list)["scores"]
answer_id = vqa_logits.argmax().item()
answer = answer_vocab.idx2word(answer_id)
print(chr(27) + "[2J") # clear the terminal
print(f"{question}: {answer}")
```

Expected output: `What is on his head?: hat`
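
To reproduce the ~15 ms forward-pass figure quoted at the top of this section, a rough timing sketch (warm-up plus averaged runs; numbers will vary with hardware and batch size) could look like this:
```python
import time

# rough latency measurement: warm up, then average several timed forward passes
with torch.no_grad():
    for _ in range(5):  # warm-up
        vilt_model(sample_list)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    n_runs = 20
    for _ in range(n_runs):
        vilt_model(sample_list)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    print(f"avg forward pass: {(time.perf_counter() - start) / n_runs * 1000:.1f} ms")
```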