[docs] Add ViLT model code usage example
Add an example of running ViLT from code outside
of the MMF CLI at the end of the ViLT tutorial.
The example runs ViLT VQA on a raw image and text.

ghstack-source-id: 4f7804ebe9ef9088197cbb83211e815219ecdfd8
Pull Request resolved: #1179
Ryan-Qiyu-Jiang committed Dec 16, 2021
1 parent d412c25 commit 3205392
Showing 1 changed file with 100 additions and 0 deletions.
100 changes: 100 additions & 0 deletions website/docs/projects/vilt.md
@@ -45,3 +45,103 @@ To pretrain a ViLT model from scratch on the COCO dataset,
```
mmf_run config=projects/vilt/configs/masked_coco/pretrain.yaml run_type=train_val dataset=masked_coco model=vilt
```

## Using the ViLT model from code
Here is an example of running the ViLT model from code to do visual question answering (VQA) on a raw image and text.
The forward pass takes ~15ms, which is very fast compared to UNITER's ~600ms.

```python
from argparse import Namespace

import torch
from mmf.common.sample import SampleList
from mmf.datasets.processors.bert_processors import VILTTextTokenizer
from mmf.datasets.processors.image_processors import VILTImageProcessor
from mmf.utils.build import build_model
from mmf.utils.configuration import Configuration, load_yaml
from mmf.utils.general import get_current_device
from mmf.utils.text import VocabDict
from omegaconf import OmegaConf
from PIL import Image
```

Build the model config and instantiate the ViLT model.
```python
# make model config for vilt vqa2
model_name = "vilt"
config_args = Namespace(
    config_override=None,
    opts=["model=vilt", "dataset=vqa2", "config=configs/defaults.yaml"],
)
default_config = Configuration(config_args).get_config()
model_vqa_config = load_yaml(
    "/private/home/your/path/to/mmf/projects/vilt/configs/vqa2/defaults.yaml"
)
config = OmegaConf.merge(default_config, model_vqa_config)
OmegaConf.resolve(config)
model_config = config.model_config[model_name]
model_config.model = model_name
vilt_model = build_model(model_config)
```

Load the model weights. Here `model_checkpoint_path` is the checkpoint downloaded from the model zoo key `vilt.vqa`,
whose current URL is `s3://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz`.
```python
# build model and load weights
model_checkpoint_path = './vilt_vqa2.pth'
state_dict = torch.load(model_checkpoint_path)
vilt_model.load_state_dict(state_dict, strict=False)
vilt_model.eval()
vilt_model = vilt_model.to(get_current_device())
```

Prepare the input image and text.
This example uses an image of a man with a hat kissing his daughter.
The text is the question posed to the ViLT model for visual question answering.
```python
# get image input
image_processor = VILTImageProcessor({"size": [384, 384]})
image_path = "./kissing_image.jpg"
raw_img = Image.open(image_path).convert("RGB")
image = image_processor(raw_img)

# get text input
text_tokenizer = VILTTextTokenizer({})
question = "What is on his head?"
processed_text_dict = text_tokenizer({"text": question})
```
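
Optionally, sanity-check the processed inputs before batching; the expected image shape below is an assumption based on the `{"size": [384, 384]}` processor config, not something this tutorial guarantees.
```python
# Optional sanity check on the processed inputs (expected values are assumptions)
print(image.shape)                 # expected to be a 3 x 384 x 384 tensor
print(processed_text_dict.keys())  # tokenized text fields passed on to the model
```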

Wrap everything up in a `SampleList`, as expected by the ViLT `BaseModel`.
```python
# make batch inputs
sample_dict = {**processed_text_dict, "image": image}
sample_dict = {
    k: v.unsqueeze(0) for k, v in sample_dict.items() if isinstance(v, torch.Tensor)
}
sample_dict["targets"] = torch.zeros((1, 3129))
sample_dict["targets"][0,1358] = 1
sample_dict["dataset_name"] = "vqa2"
sample_dict["dataset_type"] = "test"
sample_list = SampleList(sample_dict).to(get_current_device())
```

Load the VQA answer-id-to-word map so we can read off the prediction.
The file is currently available at `s3://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt`.
```python
# load vqa2 id -> answers
vocab_file_path = "/private/home/path/to/answers_vqa.txt"
answer_vocab = VocabDict(vocab_file_path)
```
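
If you do not have the vocabulary file locally, here is a minimal sketch for fetching it, under the same assumption as above that the `s3://` path is also served over HTTPS.
```python
# Minimal sketch (assumption: the vocab file is reachable over HTTPS at the same host)
import urllib.request

vocab_url = (
    "https://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/"
    "extras/vocabs/answers_vqa.txt"
)
urllib.request.urlretrieve(vocab_url, "./answers_vqa.txt")
answer_vocab = VocabDict("./answers_vqa.txt")
```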

And here's the part you've been waiting for!
```python
# do prediction
with torch.no_grad():
    vqa_logits = vilt_model(sample_list)["scores"]
answer_id = vqa_logits.argmax().item()
answer = answer_vocab.idx2word(answer_id)
print(chr(27) + "[2J") # clear the terminal
print(f"{question}: {answer}")
```

Expected output: `What is on his head?: hat`
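
If you want more than the single best answer, a small follow-up sketch is to look at the top few logits (note these are raw scores, not probabilities):
```python
# Sketch: inspect the top-5 predicted answers instead of only the argmax
with torch.no_grad():
    vqa_logits = vilt_model(sample_list)["scores"]
top_scores, top_ids = vqa_logits.squeeze(0).topk(5)
for score, idx in zip(top_scores.tolist(), top_ids.tolist()):
    print(f"{answer_vocab.idx2word(idx)}: {score:.3f}")
```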
